diff --git a/src/routers/search_router.py b/src/routers/search_router.py index aa7d8a87..ea51c12b 100644 --- a/src/routers/search_router.py +++ b/src/routers/search_router.py @@ -1,18 +1,18 @@ import abc from typing import TypeVar, Generic, Any, Type, Annotated -from elasticsearch import Elasticsearch from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel from sqlalchemy.engine import Engine from sqlmodel import SQLModel, Session, select from starlette import status -from database.model.concept.concept import AIoDConcept from database.model.concept.aiod_entry import AIoDEntryRead -from database.model.resource_read_and_create import resource_read +from database.model.concept.concept import AIoDConcept from database.model.platform.platform import Platform +from database.model.resource_read_and_create import resource_read from .resource_router import _wrap_as_http_exception +from .search_routers.elasticsearch import ElasticsearchSingleton SORT = {"identifier": "asc"} LIMIT_MAX = 1000 @@ -33,9 +33,6 @@ class SearchRouter(Generic[RESOURCE], abc.ABC): Providing search functionality in ElasticSearch """ - def __init__(self, client: Elasticsearch): - self.client: Elasticsearch = client - @property @abc.abstractmethod def es_index(self) -> str: @@ -140,7 +137,7 @@ def search( # Launch search query # ----------------------------------------------------------------- - result = self.client.search( + result = ElasticsearchSingleton().client.search( index=self.es_index, query=query, from_=offset, size=limit, sort=SORT ) diff --git a/src/routers/search_routers/__init__.py b/src/routers/search_routers/__init__.py index e1980838..7dcceb07 100644 --- a/src/routers/search_routers/__init__.py +++ b/src/routers/search_routers/__init__.py @@ -1,6 +1,3 @@ -import os -from elasticsearch import Elasticsearch - from .search_router_datasets import SearchRouterDatasets from .search_router_events import SearchRouterEvents from .search_router_experiments import 
SearchRouterExperiments @@ -12,19 +9,14 @@ from .search_router_services import SearchRouterServices from ..search_router import SearchRouter -# Elasticsearch client -user = os.getenv("ES_USER") -pw = os.getenv("ES_PASSWORD") -es_client = Elasticsearch("http://elasticsearch:9200", basic_auth=(user, pw)) - router_list: list[SearchRouter] = [ - SearchRouterDatasets(client=es_client), - SearchRouterEvents(client=es_client), - SearchRouterExperiments(client=es_client), - SearchRouterMLModels(client=es_client), - SearchRouterNews(client=es_client), - SearchRouterOrganisations(client=es_client), - SearchRouterProjects(client=es_client), - SearchRouterPublications(client=es_client), - SearchRouterServices(client=es_client), + SearchRouterDatasets(), + SearchRouterEvents(), + SearchRouterExperiments(), + SearchRouterMLModels(), + SearchRouterNews(), + SearchRouterOrganisations(), + SearchRouterProjects(), + SearchRouterPublications(), + SearchRouterServices(), ] diff --git a/src/routers/search_routers/elasticsearch.py b/src/routers/search_routers/elasticsearch.py new file mode 100644 index 00000000..7e423e91 --- /dev/null +++ b/src/routers/search_routers/elasticsearch.py @@ -0,0 +1,24 @@ +import os + +from elasticsearch import Elasticsearch + + +class ElasticsearchSingleton: + """ + Making sure the Elasticsearch client is created only once, and easy to patch for + unittests. 
+ """ + + __monostate = None + + def __init__(self): + if not ElasticsearchSingleton.__monostate: + ElasticsearchSingleton.__monostate = self.__dict__ + user = os.getenv("ES_USER", "") + pw = os.getenv("ES_PASSWORD", "") + self.client = Elasticsearch("http://elasticsearch:9200", basic_auth=(user, pw)) + else: + self.__dict__ = ElasticsearchSingleton.__monostate + + def patch(self, elasticsearch: Elasticsearch): + self.__monostate["client"] = elasticsearch # type:ignore diff --git a/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py b/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py index 24821fa0..c1df2f73 100644 --- a/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py +++ b/src/tests/connectors/huggingface/test_huggingface_dataset_connector.py @@ -2,48 +2,48 @@ import responses -# from connectors.huggingface.huggingface_dataset_connector import HuggingFaceDatasetConnector -# from connectors.resource_with_relations import ResourceWithRelations +from connectors.huggingface.huggingface_dataset_connector import HuggingFaceDatasetConnector +from connectors.resource_with_relations import ResourceWithRelations from tests.testutils.paths import path_test_resources HUGGINGFACE_URL = "https://datasets-server.huggingface.co" -# def test_fetch_all_happy_path(): -# ids_expected = { -# "0n1xus/codexglue", -# "04-07-22/wep-probes", -# "rotten_tomatoes", -# "acronym_identification", -# "air_dialogue", -# "bobbydylan/top2k", -# } -# connector = HuggingFaceDatasetConnector() -# with responses.RequestsMock() as mocked_requests: -# path_data_list = path_test_resources() / "connectors" / "huggingface" / "data_list.json" -# with open(path_data_list, "r") as f: -# response = json.load(f) -# mocked_requests.add( -# responses.GET, -# "https://huggingface.co/api/datasets?full=True", -# json=response, -# status=200, -# ) -# for dataset_id in ids_expected: -# mock_parquet(mocked_requests, dataset_id) -# resources_with_relations = 
list(connector.fetch()) -# -# assert len(resources_with_relations) == len(ids_expected) -# assert all(type(r) == ResourceWithRelations for r in resources_with_relations) -# -# datasets = [r.resource for r in resources_with_relations] -# assert {d.platform_resource_identifier for d in datasets} == ids_expected -# assert {d.name for d in datasets} == ids_expected -# assert all(d.date_published for d in datasets) -# assert all(d.aiod_entry for d in datasets) -# -# assert all(len(r.related_resources) in (1, 2) for r in resources_with_relations) -# assert all(len(r.related_resources["citation"]) == 1 for r in resources_with_relations[:5]) +def test_fetch_all_happy_path(): + ids_expected = { + "0n1xus/codexglue", + "04-07-22/wep-probes", + "rotten_tomatoes", + "acronym_identification", + "air_dialogue", + "bobbydylan/top2k", + } + connector = HuggingFaceDatasetConnector() + with responses.RequestsMock() as mocked_requests: + path_data_list = path_test_resources() / "connectors" / "huggingface" / "data_list.json" + with open(path_data_list, "r") as f: + response = json.load(f) + mocked_requests.add( + responses.GET, + "https://huggingface.co/api/datasets?full=True", + json=response, + status=200, + ) + for dataset_id in ids_expected: + mock_parquet(mocked_requests, dataset_id) + resources_with_relations = list(connector.fetch()) + + assert len(resources_with_relations) == len(ids_expected) + assert all(isinstance(r, ResourceWithRelations) for r in resources_with_relations) + + datasets = [r.resource for r in resources_with_relations] + assert {d.platform_resource_identifier for d in datasets} == ids_expected + assert {d.name for d in datasets} == ids_expected + assert all(d.date_published for d in datasets) + assert all(d.aiod_entry for d in datasets) + + assert all(len(r.related_resources) in (1, 2) for r in resources_with_relations) + assert all(len(r.related_resources["citation"]) == 1 for r in resources_with_relations[:5]) def mock_parquet(mocked_requests: 
responses.RequestsMock, dataset_id: str): diff --git a/src/tests/routers/search_routers/test_search_routers.py b/src/tests/routers/search_routers/test_search_routers.py index fbf267a5..74915a45 100644 --- a/src/tests/routers/search_routers/test_search_routers.py +++ b/src/tests/routers/search_routers/test_search_routers.py @@ -2,14 +2,19 @@ import json from unittest.mock import Mock + +from elasticsearch import Elasticsearch from starlette.testclient import TestClient + +from routers.search_routers.elasticsearch import ElasticsearchSingleton from tests.testutils.paths import path_test_resources import routers.search_routers as sr def test_search_happy_path(client: TestClient): """Tests the search router""" - + mocked_elasticsearch = Elasticsearch("https://example.com:9200") + ElasticsearchSingleton().patch(mocked_elasticsearch) for search_router in sr.router_list: # Get the mocker results to test @@ -20,7 +25,7 @@ def test_search_happy_path(client: TestClient): mocked_results = json.load(f) # Mock and launch - search_router.client.search = Mock(return_value=mocked_results) + mocked_elasticsearch.search = Mock(return_value=mocked_results) search_service = f"/search/{search_router.resource_name_plural}/v1" params = {"search_query": "description", "get_all": False} response = client.get(search_service, params=params) @@ -37,7 +42,7 @@ def test_search_happy_path(client: TestClient): assert resource["aiod_entry"]["date_modified"] == "2023-09-01T00:00:00+00:00" # Test the extra fields - global_fields = set(["name", "plain", "html"]) + global_fields = {"name", "plain", "html"} extra_fields = list(search_router.match_fields ^ global_fields) for field in extra_fields: assert resource[field] @@ -45,6 +50,8 @@ def test_search_happy_path(client: TestClient): def test_search_bad_platform(client: TestClient): """Tests the search router bad platform error""" + mocked_elasticsearch = Elasticsearch("https://example.com:9200") + ElasticsearchSingleton().patch(mocked_elasticsearch) 
for search_router in sr.router_list: @@ -56,7 +63,7 @@ def test_search_bad_platform(client: TestClient): mocked_results = json.load(f) # Mock and launch - search_router.client.search = Mock(return_value=mocked_results) + mocked_elasticsearch.search = Mock(return_value=mocked_results) search_service = f"/search/{search_router.resource_name_plural}/v1" params = {"search_query": "description", "platforms": ["bad_platform"]} response = client.get(search_service, params=params) @@ -69,6 +76,8 @@ def test_search_bad_platform(client: TestClient): def test_search_bad_fields(client: TestClient): """Tests the search router bad fields error""" + mocked_elasticsearch = Elasticsearch("https://example.com:9200") + ElasticsearchSingleton().patch(mocked_elasticsearch) for search_router in sr.router_list: @@ -80,7 +89,7 @@ def test_search_bad_fields(client: TestClient): mocked_results = json.load(f) # Mock and launch - search_router.client.search = Mock(return_value=mocked_results) + mocked_elasticsearch.search = Mock(return_value=mocked_results) search_service = f"/search/{search_router.resource_name_plural}/v1" params = {"search_query": "description", "search_fields": ["bad_field"]} response = client.get(search_service, params=params) @@ -93,6 +102,8 @@ def test_search_bad_fields(client: TestClient): def test_search_bad_limit(client: TestClient): """Tests the search router bad fields error""" + mocked_elasticsearch = Elasticsearch("https://example.com:9200") + ElasticsearchSingleton().patch(mocked_elasticsearch) for search_router in sr.router_list: @@ -104,7 +115,7 @@ def test_search_bad_limit(client: TestClient): mocked_results = json.load(f) # Mock and launch - search_router.client.search = Mock(return_value=mocked_results) + mocked_elasticsearch.search = Mock(return_value=mocked_results) search_service = f"/search/{search_router.resource_name_plural}/v1" params = {"search_query": "description", "limit": 1001} response = client.get(search_service, params=params) @@ -117,6 
+128,8 @@ def test_search_bad_limit(client: TestClient): def test_search_bad_offset(client: TestClient): """Tests the search router bad fields error""" + mocked_elasticsearch = Elasticsearch("https://example.com:9200") + ElasticsearchSingleton().patch(mocked_elasticsearch) for search_router in sr.router_list: @@ -128,7 +141,7 @@ def test_search_bad_offset(client: TestClient): mocked_results = json.load(f) # Mock and launch - search_router.client.search = Mock(return_value=mocked_results) + mocked_elasticsearch.search = Mock(return_value=mocked_results) search_service = f"/search/{search_router.resource_name_plural}/v1" params = {"search_query": "description", "offset": -1} response = client.get(search_service, params=params)