Skip to content

Commit

Permalink
Fixing unittests by making sure Elasticsearch instance can also be cr…
Browse files Browse the repository at this point in the history
…eated when ES_USER and ES_PASSWORD env vars are empty; used the style of PR #199
  • Loading branch information
josvandervelde committed Nov 20, 2023
1 parent 120f97a commit 3e5c446
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 68 deletions.
11 changes: 4 additions & 7 deletions src/routers/search_router.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import abc
from typing import TypeVar, Generic, Any, Type, Annotated

from elasticsearch import Elasticsearch
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.engine import Engine
from sqlmodel import SQLModel, Session, select
from starlette import status

from database.model.concept.concept import AIoDConcept
from database.model.concept.aiod_entry import AIoDEntryRead
from database.model.resource_read_and_create import resource_read
from database.model.concept.concept import AIoDConcept
from database.model.platform.platform import Platform
from database.model.resource_read_and_create import resource_read
from .resource_router import _wrap_as_http_exception
from .search_routers.elasticsearch import ElasticsearchSingleton

SORT = {"identifier": "asc"}
LIMIT_MAX = 1000
Expand All @@ -33,9 +33,6 @@ class SearchRouter(Generic[RESOURCE], abc.ABC):
Providing search functionality in ElasticSearch
"""

def __init__(self, client: Elasticsearch):
self.client: Elasticsearch = client

@property
@abc.abstractmethod
def es_index(self) -> str:
Expand Down Expand Up @@ -140,7 +137,7 @@ def search(
# Launch search query
# -----------------------------------------------------------------

result = self.client.search(
result = ElasticsearchSingleton().client.search(
index=self.es_index, query=query, from_=offset, size=limit, sort=SORT
)

Expand Down
26 changes: 9 additions & 17 deletions src/routers/search_routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
from elasticsearch import Elasticsearch

from .search_router_datasets import SearchRouterDatasets
from .search_router_events import SearchRouterEvents
from .search_router_experiments import SearchRouterExperiments
Expand All @@ -12,19 +9,14 @@
from .search_router_services import SearchRouterServices
from ..search_router import SearchRouter

# Elasticsearch client
user = os.getenv("ES_USER")
pw = os.getenv("ES_PASSWORD")
es_client = Elasticsearch("http://elasticsearch:9200", basic_auth=(user, pw))

router_list: list[SearchRouter] = [
SearchRouterDatasets(client=es_client),
SearchRouterEvents(client=es_client),
SearchRouterExperiments(client=es_client),
SearchRouterMLModels(client=es_client),
SearchRouterNews(client=es_client),
SearchRouterOrganisations(client=es_client),
SearchRouterProjects(client=es_client),
SearchRouterPublications(client=es_client),
SearchRouterServices(client=es_client),
SearchRouterDatasets(),
SearchRouterEvents(),
SearchRouterExperiments(),
SearchRouterMLModels(),
SearchRouterNews(),
SearchRouterOrganisations(),
SearchRouterProjects(),
SearchRouterPublications(),
SearchRouterServices(),
]
24 changes: 24 additions & 0 deletions src/routers/search_routers/elasticsearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

from elasticsearch import Elasticsearch


class ElasticsearchSingleton:
    """
    Shares a single Elasticsearch client between all instances of this class.

    The first instance that is constructed creates the client. Every later instance adopts
    the attribute dictionary of that first instance, so `.client` refers to the same object
    everywhere. Unittests can swap the shared client for a stand-in through `patch`.
    """

    # Attribute dictionary shared by all instances; stays None until the first construction.
    __monostate = None

    def __init__(self):
        shared_state = ElasticsearchSingleton.__monostate
        if shared_state:
            # A client already exists: reuse the shared attributes instead of creating another.
            self.__dict__ = shared_state
            return

        # First construction: publish this instance's attribute dict as the shared state,
        # then create the client. Unset environment variables fall back to empty strings.
        ElasticsearchSingleton.__monostate = self.__dict__
        credentials = (os.getenv("ES_USER", ""), os.getenv("ES_PASSWORD", ""))
        self.client = Elasticsearch("http://elasticsearch:9200", basic_auth=credentials)

    def patch(self, elasticsearch: Elasticsearch):
        """Replace the shared client with `elasticsearch` for every instance of this class."""
        shared_state = ElasticsearchSingleton.__monostate
        shared_state["client"] = elasticsearch  # type:ignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,48 @@

import responses

# from connectors.huggingface.huggingface_dataset_connector import HuggingFaceDatasetConnector
# from connectors.resource_with_relations import ResourceWithRelations
from connectors.huggingface.huggingface_dataset_connector import HuggingFaceDatasetConnector
from connectors.resource_with_relations import ResourceWithRelations
from tests.testutils.paths import path_test_resources

HUGGINGFACE_URL = "https://datasets-server.huggingface.co"


# def test_fetch_all_happy_path():
# ids_expected = {
# "0n1xus/codexglue",
# "04-07-22/wep-probes",
# "rotten_tomatoes",
# "acronym_identification",
# "air_dialogue",
# "bobbydylan/top2k",
# }
# connector = HuggingFaceDatasetConnector()
# with responses.RequestsMock() as mocked_requests:
# path_data_list = path_test_resources() / "connectors" / "huggingface" / "data_list.json"
# with open(path_data_list, "r") as f:
# response = json.load(f)
# mocked_requests.add(
# responses.GET,
# "https://huggingface.co/api/datasets?full=True",
# json=response,
# status=200,
# )
# for dataset_id in ids_expected:
# mock_parquet(mocked_requests, dataset_id)
# resources_with_relations = list(connector.fetch())
#
# assert len(resources_with_relations) == len(ids_expected)
# assert all(type(r) == ResourceWithRelations for r in resources_with_relations)
#
# datasets = [r.resource for r in resources_with_relations]
# assert {d.platform_resource_identifier for d in datasets} == ids_expected
# assert {d.name for d in datasets} == ids_expected
# assert all(d.date_published for d in datasets)
# assert all(d.aiod_entry for d in datasets)
#
# assert all(len(r.related_resources) in (1, 2) for r in resources_with_relations)
# assert all(len(r.related_resources["citation"]) == 1 for r in resources_with_relations[:5])
def test_fetch_all_happy_path():
    """
    Fetch all datasets through the HuggingFace connector against mocked HTTP responses.

    The dataset listing endpoint is mocked with the contents of the `data_list.json` test
    resource, and `mock_parquet` is called once per expected dataset id. The test then checks
    that every expected dataset is returned exactly once, with the fields and related
    resources asserted below.
    """
    ids_expected = {
        "0n1xus/codexglue",
        "04-07-22/wep-probes",
        "rotten_tomatoes",
        "acronym_identification",
        "air_dialogue",
        "bobbydylan/top2k",
    }
    connector = HuggingFaceDatasetConnector()
    with responses.RequestsMock() as mocked_requests:
        path_data_list = path_test_resources() / "connectors" / "huggingface" / "data_list.json"
        # Read the fixture as UTF-8 explicitly, so the result does not depend on the locale
        # of the machine running the tests.
        with open(path_data_list, "r", encoding="utf-8") as f:
            response = json.load(f)
        mocked_requests.add(
            responses.GET,
            "https://huggingface.co/api/datasets?full=True",
            json=response,
            status=200,
        )
        for dataset_id in ids_expected:
            mock_parquet(mocked_requests, dataset_id)
        # Materialise the results while the mocked responses are still active.
        resources_with_relations = list(connector.fetch())

    assert len(resources_with_relations) == len(ids_expected)
    assert all(isinstance(r, ResourceWithRelations) for r in resources_with_relations)

    datasets = [r.resource for r in resources_with_relations]
    assert {d.platform_resource_identifier for d in datasets} == ids_expected
    assert {d.name for d in datasets} == ids_expected
    assert all(d.date_published for d in datasets)
    assert all(d.aiod_entry for d in datasets)

    assert all(len(r.related_resources) in (1, 2) for r in resources_with_relations)
    # Only the first five results are checked for a citation.
    assert all(len(r.related_resources["citation"]) == 1 for r in resources_with_relations[:5])


def mock_parquet(mocked_requests: responses.RequestsMock, dataset_id: str):
Expand Down
27 changes: 20 additions & 7 deletions src/tests/routers/search_routers/test_search_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
import json

from unittest.mock import Mock

from elasticsearch import Elasticsearch
from starlette.testclient import TestClient

from routers.search_routers.elasticsearch import ElasticsearchSingleton
from tests.testutils.paths import path_test_resources
import routers.search_routers as sr


def test_search_happy_path(client: TestClient):
"""Tests the search router"""

mocked_elasticsearch = Elasticsearch("https://example.com:9200")
ElasticsearchSingleton().patch(mocked_elasticsearch)
for search_router in sr.router_list:

# Get the mocker results to test
Expand All @@ -20,7 +25,7 @@ def test_search_happy_path(client: TestClient):
mocked_results = json.load(f)

# Mock and launch
search_router.client.search = Mock(return_value=mocked_results)
mocked_elasticsearch.search = Mock(return_value=mocked_results)
search_service = f"/search/{search_router.resource_name_plural}/v1"
params = {"search_query": "description", "get_all": False}
response = client.get(search_service, params=params)
Expand All @@ -37,14 +42,16 @@ def test_search_happy_path(client: TestClient):
assert resource["aiod_entry"]["date_modified"] == "2023-09-01T00:00:00+00:00"

# Test the extra fields
global_fields = set(["name", "plain", "html"])
global_fields = {"name", "plain", "html"}
extra_fields = list(search_router.match_fields ^ global_fields)
for field in extra_fields:
assert resource[field]


def test_search_bad_platform(client: TestClient):
"""Tests the search router bad platform error"""
mocked_elasticsearch = Elasticsearch("https://example.com:9200")
ElasticsearchSingleton().patch(mocked_elasticsearch)

for search_router in sr.router_list:

Expand All @@ -56,7 +63,7 @@ def test_search_bad_platform(client: TestClient):
mocked_results = json.load(f)

# Mock and launch
search_router.client.search = Mock(return_value=mocked_results)
mocked_elasticsearch.search = Mock(return_value=mocked_results)
search_service = f"/search/{search_router.resource_name_plural}/v1"
params = {"search_query": "description", "platforms": ["bad_platform"]}
response = client.get(search_service, params=params)
Expand All @@ -69,6 +76,8 @@ def test_search_bad_platform(client: TestClient):

def test_search_bad_fields(client: TestClient):
"""Tests the search router bad fields error"""
mocked_elasticsearch = Elasticsearch("https://example.com:9200")
ElasticsearchSingleton().patch(mocked_elasticsearch)

for search_router in sr.router_list:

Expand All @@ -80,7 +89,7 @@ def test_search_bad_fields(client: TestClient):
mocked_results = json.load(f)

# Mock and launch
search_router.client.search = Mock(return_value=mocked_results)
mocked_elasticsearch.search = Mock(return_value=mocked_results)
search_service = f"/search/{search_router.resource_name_plural}/v1"
params = {"search_query": "description", "search_fields": ["bad_field"]}
response = client.get(search_service, params=params)
Expand All @@ -93,6 +102,8 @@ def test_search_bad_fields(client: TestClient):

def test_search_bad_limit(client: TestClient):
"""Tests the search router bad fields error"""
mocked_elasticsearch = Elasticsearch("https://example.com:9200")
ElasticsearchSingleton().patch(mocked_elasticsearch)

for search_router in sr.router_list:

Expand All @@ -104,7 +115,7 @@ def test_search_bad_limit(client: TestClient):
mocked_results = json.load(f)

# Mock and launch
search_router.client.search = Mock(return_value=mocked_results)
mocked_elasticsearch.search = Mock(return_value=mocked_results)
search_service = f"/search/{search_router.resource_name_plural}/v1"
params = {"search_query": "description", "limit": 1001}
response = client.get(search_service, params=params)
Expand All @@ -117,6 +128,8 @@ def test_search_bad_limit(client: TestClient):

def test_search_bad_offset(client: TestClient):
"""Tests the search router bad fields error"""
mocked_elasticsearch = Elasticsearch("https://example.com:9200")
ElasticsearchSingleton().patch(mocked_elasticsearch)

for search_router in sr.router_list:

Expand All @@ -128,7 +141,7 @@ def test_search_bad_offset(client: TestClient):
mocked_results = json.load(f)

# Mock and launch
search_router.client.search = Mock(return_value=mocked_results)
mocked_elasticsearch.search = Mock(return_value=mocked_results)
search_service = f"/search/{search_router.resource_name_plural}/v1"
params = {"search_query": "description", "offset": -1}
response = client.get(search_service, params=params)
Expand Down

0 comments on commit 3e5c446

Please sign in to comment.