diff --git a/api/barriers/models.py b/api/barriers/models.py index 074d922fe..debc08ab6 100644 --- a/api/barriers/models.py +++ b/api/barriers/models.py @@ -1640,10 +1640,9 @@ def vector_search(self, queryset, name, value): SIMILARITY_THRESHOLD, ) - if handler_manager.manager is None: - handler_manager.init() + related_barriers = handler_manager.get_or_init() - barrier_scores = handler_manager.manager.get_similar_barriers_searched( + barrier_scores = related_barriers.get_similar_barriers_searched( search_term=value, similarity_threshold=SIMILARITY_THRESHOLD, quantity=SIMILAR_BARRIERS_LIMIT, diff --git a/api/barriers/signals/handlers.py b/api/barriers/signals/handlers.py index 6ae785415..18504e1c5 100644 --- a/api/barriers/signals/handlers.py +++ b/api/barriers/signals/handlers.py @@ -21,9 +21,8 @@ send_top_priority_notification, ) from api.metadata.constants import TOP_PRIORITY_BARRIER_STATUS -from api.related_barriers import manager -from api.related_barriers.constants import BarrierEntry from api.related_barriers.manager import BARRIER_UPDATE_FIELDS +from api.related_barriers.tasks import update_related_barrier logger = logging.getLogger(__name__) @@ -228,15 +227,4 @@ def related_barrier_update_embeddings(sender, instance, *args, **kwargs): ) if changed and not current_barrier_object.draft: - if not manager.manager: - manager.init() - try: - manager.manager.update_barrier( - BarrierEntry( - id=str(current_barrier_object.id), - barrier_corpus=manager.barrier_to_corpus(instance), - ) - ) - except Exception as e: - # We don't want barrier embedding updates to break worker so just log error - logger.critical(str(e)) + update_related_barrier(barrier_id=str(instance.pk)) diff --git a/api/related_barriers/management/commands/related_barriers.py b/api/related_barriers/management/commands/related_barriers.py new file mode 100644 index 000000000..ef87e6161 --- /dev/null +++ b/api/related_barriers/management/commands/related_barriers.py @@ -0,0 +1,48 @@ +import logging +import time + +from django.core.management import BaseCommand + +from api.related_barriers import manager + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Related Barrier service management" + + def add_arguments(self, parser): + parser.add_argument("--flush", type=bool, help="Flush cache") + parser.add_argument("--reindex", type=bool, help="Reindex cache") + parser.add_argument( + "--stats", type=bool, help="Get Related Barrier cache stats" + ) + + def handle(self, *args, **options): + flush = options["flush"] + reindex = options["reindex"] + stats = options["stats"] + + rb_manager = manager.RelatedBarrierManager() + + barrier_count = len(rb_manager.get_barrier_ids()) + + if stats: + logger.info(f"Barrier Count: {barrier_count}") + + if reindex: + logger.info(f"Reindexing {barrier_count} barriers") + data = manager.get_data() + + rb_manager.flush() + + s = time.time() + rb_manager.set_data(data) + end = time.time() + logger.info("Training time: ", end - s) + return + + if flush: + logger.info("Flushing") + rb_manager.flush() + return diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index 52afae1d9..f20c630a1 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional import numpy +import torch from django.core.cache import cache from sentence_transformers import SentenceTransformer, util @@ -29,8 +30,7 @@ def wrapper(*args, **kwargs): @timing def get_transformer(): - # SentenceTransformer("all-MiniLM-L6-v2") - return SentenceTransformer("paraphrase-MiniLM-L3-v2") + return SentenceTransformer("all-MiniLM-L6-v2") BARRIER_UPDATE_FIELDS: List[str] = ["title", "summary"] @@ -164,7 +164,7 @@ def remove_barrier(self, barrier: BarrierEntry, barrier_ids=None) -> None: @timing def update_barrier(self, barrier: BarrierEntry) -> None: logger.info(f"(Related Barriers): update_barrier {barrier.id}") - barrier_ids = manager.get_barrier_ids() + barrier_ids = self.get_barrier_ids() if barrier.id in barrier_ids: self.remove_barrier(barrier, barrier_ids) self.add_barrier(barrier, barrier_ids) @@ -200,7 +200,8 @@ def get_similar_barriers( for k, v in barrier_scores.items() if v > similarity_threshold and k != barrier.id } - return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + return [(r[0], torch.round(torch.tensor(r[1]), decimals=4)) for r in results] @timing def get_similar_barriers_searched( @@ -222,7 +223,7 @@ def get_similar_barriers_searched( logger.info("(Related Barriers): get_similar_barriers_searched") if not (barrier_ids := self.get_barrier_ids()): - self.set_data(get_data_2()) + self.set_data(get_data()) barrier_ids = self.get_barrier_ids() or [] if not barrier_ids: @@ -249,87 +250,28 @@ def get_similar_barriers_searched( for k, v in barrier_scores.items() if v > similarity_threshold and k != embedded_index } - return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] - - -manager: Optional[RelatedBarrierManager] = None + results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + return [(r[0], torch.round(r[1], decimals=4)) for r in results] def get_data() -> List[Dict]: from api.barriers.models import Barrier - return [ - {"id": barrier.id, "barrier_corpus": f"{barrier.title} . {barrier.summary}"} - for barrier in Barrier.objects.filter(archived=False).exclude(draft=True) - ] - - -def get_data_2() -> List[Dict]: - from api.barriers.models import Barrier - data_dictionary = [] - barriers = Barrier.objects.filter(archived=False).exclude(draft=True) + barriers = ( + Barrier.objects.prefetch_related("interactions_documents") + .filter(archived=False) + .exclude(draft=True) + ) for barrier in barriers: - corpus_object = {} - - companies_affected_list = "" - if barrier.companies: - companies_affected_list = ", ".join( - [company["name"] for company in barrier.companies] - ) - - other_organisations_affected_list = "" - if barrier.related_organisations: - other_organisations_affected_list = ", ".join( - [company["name"] for company in barrier.related_organisations] - ) - - notes_text_list = ", ".join( - [note.text for note in barrier.interactions_documents.all()] - ) - - sectors_list = [ - get_sector(str(sector_id))["name"] for sector_id in barrier.sectors - ] - sectors_list.append(get_sector(barrier.main_sector)["name"]) - sectors_text = ", ".join(sectors_list) - - overseas_region_text = get_barriers_overseas_region( - barrier.country, barrier.trading_bloc + data_dictionary.append( + {"id": str(barrier.id), "barrier_corpus": barrier_to_corpus(barrier)} ) - estimated_resolution_date_text = "" - if barrier.estimated_resolution_date: - date = barrier.estimated_resolution_date.strftime("%d-%m-%Y") - estimated_resolution_date_text = f"Estimated to be resolved on {date}." - - corpus_object["id"] = barrier.id - corpus_object["barrier_corpus"] = ( - f"{barrier.code}. " - f"{barrier.title}. " - f"{barrier.summary}. " - f"{sectors_text} " - f"{barrier.country_name}. " - f"{overseas_region_text}. " - f"{companies_affected_list} " - f"{other_organisations_affected_list} " - f"{notes_text_list}. " - f"{barrier.status_summary}. " - f"{estimated_resolution_date_text}. " - f"{barrier.export_description}." - ) - - data_dictionary.append(corpus_object) - return data_dictionary -def init(): - global manager - - if manager: - raise Exception("Related Barrier Manager already set") - +def get_or_init() -> RelatedBarrierManager: manager = RelatedBarrierManager() if not manager.get_barrier_ids(): @@ -337,6 +279,50 @@ def init(): data = get_data() manager.set_data(data) + return manager + def barrier_to_corpus(barrier) -> str: - return barrier.title + ". " + barrier.summary + companies_affected_list = "" + if barrier.companies: + companies_affected_list = ", ".join( + [company["name"] for company in barrier.companies] + ) + + other_organisations_affected_list = "" + if barrier.related_organisations: + other_organisations_affected_list = ", ".join( + [company["name"] for company in barrier.related_organisations] + ) + + notes_text_list = ", ".join( + [note.text for note in barrier.interactions_documents.all()] + ) + + sectors_list = [ + get_sector(str(sector_id))["name"] + for sector_id in barrier.sectors + if get_sector(str(sector_id)) + ] + main_sector = get_sector(barrier.main_sector) + if main_sector: + sectors_list.append(main_sector["name"]) + sectors_text = ", ".join(sectors_list) + + overseas_region_text = get_barriers_overseas_region( + barrier.country, barrier.trading_bloc + ) + + estimated_resolution_date_text = "" + if barrier.estimated_resolution_date: + date = barrier.estimated_resolution_date.strftime("%d-%m-%Y") + estimated_resolution_date_text = f"Estimated to be resolved on {date}." + + return ( + f"{barrier.title}. {barrier.summary}. " + f"{sectors_text} {barrier.country_name}. " + f"{overseas_region_text}. {companies_affected_list} " + f"{other_organisations_affected_list} {notes_text_list}. " + f"{barrier.status_summary}. {estimated_resolution_date_text}. " + f"{barrier.export_description}." + ) diff --git a/api/related_barriers/tasks.py b/api/related_barriers/tasks.py new file mode 100644 index 000000000..129ae9def --- /dev/null +++ b/api/related_barriers/tasks.py @@ -0,0 +1,33 @@ +import logging + +from celery import shared_task +from django.core.management import call_command + +from api.barriers.models import Barrier +from api.related_barriers import manager +from api.related_barriers.constants import BarrierEntry + +logger = logging.getLogger(__name__) + + +def update_related_barrier(barrier_id: str): + logger.info(f"Updating related barrier embeddings for: {barrier_id}") + related_barriers = manager.get_or_init() + try: + related_barriers.update_barrier( + BarrierEntry( + id=barrier_id, + barrier_corpus=manager.barrier_to_corpus( + Barrier.objects.get(pk=barrier_id) + ), + ) + ) + except Exception as e: + # We don't want barrier embedding updates to break worker so just log error + logger.critical(str(e)) + + +@shared_task +def reindex_related_barriers(): + """Schedule daily task to reindex""" + call_command("related_barriers", "--reindex", "1") diff --git a/api/related_barriers/urls.py b/api/related_barriers/urls.py index 867498063..6e226f0fb 100644 --- a/api/related_barriers/urls.py +++ b/api/related_barriers/urls.py @@ -1,7 +1,7 @@ from django.urls import path from rest_framework.routers import DefaultRouter -from api.related_barriers.views import related_barriers, related_barriers_search +from api.related_barriers.views import related_barriers_search, related_barriers_view app_name = "related_barriers" @@ -10,7 +10,9 @@ urlpatterns = router.urls + [ path( - "barriers//related-barriers", related_barriers, name="related-barriers" + "barriers//related-barriers", + related_barriers_view, + name="related-barriers", ), path("related-barriers", related_barriers_search, name="related-barriers-search"), ] diff --git a/api/related_barriers/views.py b/api/related_barriers/views.py index b0e6c646e..49cf92ff1 100644 --- a/api/related_barriers/views.py +++ b/api/related_barriers/views.py @@ -19,17 +19,16 @@ @api_view(["GET"]) -def related_barriers(request, pk) -> Response: +def related_barriers_view(request, pk) -> Response: """ Return a list of related barriers """ logger.info(f"Getting related barriers for {pk}") barrier = get_object_or_404(Barrier, pk=pk) - if manager.manager is None: - manager.init() + related_barriers = manager.get_or_init() - barrier_scores = manager.manager.get_similar_barriers( + barrier_scores = related_barriers.get_similar_barriers( barrier=BarrierEntry( id=str(barrier.id), barrier_corpus=manager.barrier_to_corpus(barrier), @@ -60,10 +59,9 @@ def related_barriers_search(request) -> Response: serializer = SearchRequest(data=request.GET) serializer.is_valid(raise_exception=True) - if manager.manager is None: - manager.init() + related_barriers = manager.get_or_init() - barrier_scores = manager.manager.get_similar_barriers_searched( + barrier_scores = related_barriers.get_similar_barriers_searched( search_term=serializer.data["search_term"], similarity_threshold=SIMILARITY_THRESHOLD, quantity=SIMILAR_BARRIERS_LIMIT, diff --git a/config/settings/base.py b/config/settings/base.py index 5691ce711..5795d7cad 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -417,6 +417,12 @@ CELERY_BEAT_SCHEDULE = {} if not DEBUG: + # Runs daily at midnight + CELERY_BEAT_SCHEDULE["reindex_related_barriers"] = { + "task": "api.related_barriers.tasks.reindex_related_barriers", + "schedule": crontab(minute=0, hour=0), + } + # Runs daily at 6am CELERY_BEAT_SCHEDULE["send_notification_emails"] = { "task": "api.user.tasks.send_notification_emails", diff --git a/tests/barriers/signals/test_handlers.py b/tests/barriers/signals/test_handlers.py index b72507cee..b4b8d392c 100644 --- a/tests/barriers/signals/test_handlers.py +++ b/tests/barriers/signals/test_handlers.py @@ -9,8 +9,10 @@ def test_related_barrier_handler(): barrier = BarrierFactory() - with mock.patch("api.barriers.signals.handlers.manager") as mock_manager: + with mock.patch( + "api.barriers.signals.handlers.update_related_barrier" + ) as mock_update_related_barrier: barrier.title = "New Title" barrier.save() - assert mock_manager.manager.update_barrier.call_count == 1 + mock_update_related_barrier.assert_called_once_with(barrier_id=str(barrier.pk)) diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py new file mode 100644 index 000000000..9b956e1cb --- /dev/null +++ b/tests/related_barriers/test_baseline_performance.py @@ -0,0 +1,181 @@ +from typing import Dict, List, Tuple + +import pytest +import torch +from django.db.models import signals +from factory.django import mute_signals +from sentence_transformers import SentenceTransformer + +from api.barriers.models import Barrier +from api.related_barriers.constants import BarrierEntry +from api.related_barriers.manager import RelatedBarrierManager, barrier_to_corpus +from tests.barriers.factories import BarrierFactory + +pytestmark = [pytest.mark.django_db] + + +@pytest.fixture +def manager(settings): + # Loads transformer model + settings.CACHES = { + "default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"} + } + return RelatedBarrierManager() + + +@pytest.fixture +def TEST_CASE_1() -> Tuple[List[Dict], List[Barrier]]: + with mute_signals(signals.pre_save): + b1 = BarrierFactory(title="The weather is lovely today", summary="") + b2 = BarrierFactory(title="It's so sunny outside!", summary="") + b3 = BarrierFactory(title="He drove to the stadium.", summary="") + + data = ( + Barrier.objects.filter(archived=False) + .exclude(draft=True) + .values("id", "title", "summary") + ) + + return [ + {"id": str(d["id"]), "barrier_corpus": f"{d['title']} . {d['summary']}"} + for d in data + ], [b1, b2, b3] + + +@pytest.fixture +def TEST_CASE_2() -> Tuple[List[Dict], List[Barrier]]: + corpus = [ + "A man is eating food.", + "A man is eating a piece of bread.", + "The girl is carrying a baby.", + "A man is riding a horse.", + "A woman is playing violin.", + "Two men pushed carts through the woods.", + "A man is riding a white horse on an enclosed ground.", + "A monkey is playing drums.", + "A cheetah is running behind its prey.", + ] + with mute_signals(signals.pre_save): + barriers = [BarrierFactory(title=title) for title in corpus] + + data = ( + Barrier.objects.filter(archived=False).exclude(draft=True).values("id", "title") + ) + + return [ + {"id": str(d["id"]), "barrier_corpus": f"{d['title']}"} for d in data + ], barriers + + +@pytest.fixture +def TEST_CASE_3() -> Tuple[List[Dict], List[Barrier]]: + barriers = [ + BarrierFactory( + title="Translating Trademark for Whisky Products in Quebec", + summary=( + "Bill 96, published by the Quebec Government years ago, changed " + "general regulation regarding French language translations. They " + "ask business to translate everything into French, including trademarks " + "which causes particular difficulties. SWA are waiting for a regulation " + "on the implementation of this as there is a lack of clarity where in " + "the draft regulation they stated that trademarks are exempted but if it " + "includes a general term or description of the product, this should be translated.", + ), + ), + BarrierFactory( + title="Offshore wind: local content; skills & workforce engagement", + summary=( + "Taiwan requested local content for offshore wind project but Taiwanese " + "firms currently lack capabilities and experience, which combined with " + "weak domestic competition, has in many cases resulted in inflated costs " + "and unnecessary project delays. In 2021, Taiwan Energy Administration " + "allowed some additional flexibility as developers only needed to source " + "60% of items locally, compared to 100% previously. This removes the " + "absolute market access barrier but still not enough.", + ), + ), + ] + + return [ + {"id": str(b.id), "barrier_corpus": f"{b.title} . {b.summary}"} + for b in barriers + ], barriers + + +def test_transformer_loads(manager): + assert isinstance(manager.model, SentenceTransformer) + + +def test_search_case_1(manager, TEST_CASE_1): + training_data, barriers = TEST_CASE_1 + + b1, b2, b3 = barriers + + manager.set_data(training_data) + + assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) + + results1 = manager.get_similar_barriers_searched( + search_term="The weather", + similarity_threshold=0, + quantity=5, + ) + results2 = manager.get_similar_barriers_searched( + search_term="A car at the stadium", + similarity_threshold=0, + quantity=5, + ) + + assert results1 == [ + (str(b3.id), torch.tensor(0.1963)), + (str(b2.id), torch.tensor(0.5966)), + (str(b1.id), torch.tensor(0.6943)), + ] + assert results2 == [ + (str(b1.id), torch.tensor(0.1044)), + (str(b2.id), torch.tensor(0.1335)), + (str(b3.id), torch.tensor(0.6433)), + ] + + +def test_search_case_2(manager, TEST_CASE_2): + training_data, barriers = TEST_CASE_2 + + manager.set_data(training_data) + + assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) + + results1 = manager.get_similar_barriers_searched( + search_term="A man is eating pasta.", + similarity_threshold=0, + quantity=20, + ) + + assert results1 == [ + (str(barriers[4].id), torch.tensor(0.0336)), + (str(barriers[7].id), torch.tensor(0.0819)), + (str(barriers[8].id), torch.tensor(0.0980)), + (str(barriers[6].id), torch.tensor(0.1047)), + (str(barriers[3].id), torch.tensor(0.1889)), + (str(barriers[1].id), torch.tensor(0.5272)), + (str(barriers[0].id), torch.tensor(0.7035)), + ] + + +def test_related_barriers_1(manager, TEST_CASE_3): + training_data, barriers = TEST_CASE_3 + + manager.set_data(training_data) + + assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) + + barrier_entry = BarrierEntry( + id=str(barriers[0].id), barrier_corpus=barrier_to_corpus(barriers[0]) + ) + results1 = manager.get_similar_barriers( + barrier_entry, + similarity_threshold=0, + quantity=20, + ) + + assert results1 == [(str(barriers[1].id), torch.tensor(0.1556))]