From 25f81acc110a621f2d443f4fee4d6ae0f47db9e2 Mon Sep 17 00:00:00 2001 From: abarolo Date: Wed, 25 Sep 2024 13:46:22 +0100 Subject: [PATCH 01/17] init --- api/related_barriers/manager.py | 4 +- config/settings/test.py | 2 +- tests/related_barriers/conftest.py | 2 +- .../test_baseline_performance.py | 122 ++++++++++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 tests/related_barriers/test_baseline_performance.py diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index 363599fe5..ef0b36d79 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -29,8 +29,8 @@ def wrapper(*args, **kwargs): @timing def get_transformer(): - # SentenceTransformer("all-MiniLM-L6-v2") - return SentenceTransformer("paraphrase-MiniLM-L3-v2") + return SentenceTransformer("all-MiniLM-L6-v2") + # return SentenceTransformer("paraphrase-MiniLM-L3-v2") BARRIER_UPDATE_FIELDS: List[str] = ["title", "summary"] diff --git a/config/settings/test.py b/config/settings/test.py index cff42b499..38e61cd14 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -15,7 +15,7 @@ ] DOCUMENT_BUCKET = "test-bucket" -CACHES = {"default": {"BACKEND": "django.core.cache.backends.dummy.DummyCache"}} +CACHES = {"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}} # Stop WhiteNoise emitting warnings when running tests without running collectstatic first WHITENOISE_AUTOREFRESH = True diff --git a/tests/related_barriers/conftest.py b/tests/related_barriers/conftest.py index 19e17dd35..0d8694bbb 100644 --- a/tests/related_barriers/conftest.py +++ b/tests/related_barriers/conftest.py @@ -5,7 +5,7 @@ from django.db.models.functions import Concat from api.barriers.models import Barrier -from api.related_barriers.manager import RelatedBarrierManager +from api.related_barriers.manager import RelatedBarrierManager, get_transformer from tests.barriers.factories import BarrierFactory diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py new file mode 100644 index 000000000..a3ed3a71c --- /dev/null +++ b/tests/related_barriers/test_baseline_performance.py @@ -0,0 +1,122 @@ +from typing import Tuple, List, Dict + +import pytest +from django.db.models import Value as V, signals +from django.db.models.functions import Concat +from django.forms import CharField +from factory.django import mute_signals +from api.barriers.models import Barrier +from api.related_barriers.manager import RelatedBarrierManager +from sentence_transformers import SentenceTransformer +import torch +from tests.barriers.factories import BarrierFactory +from django.core.cache import cache + +pytestmark = [pytest.mark.django_db] + + +@pytest.fixture +def manager(): + # Loads transformer model + return RelatedBarrierManager() + + +@pytest.fixture +def TEST_CASE_1() -> Tuple[List[Dict], List[Barrier]]: + with mute_signals(signals.pre_save): + b1 = BarrierFactory(title="The weather is lovely today", summary="") + b2 = BarrierFactory(title="It's so sunny outside!", summary="") + b3 = BarrierFactory(title="He drove to the stadium.", summary="") + + data = ( + Barrier.objects.filter(archived=False) + .exclude(draft=True) + .values("id", "title", "summary") + ) + + return [{"id": str(d["id"]), "barrier_corpus": f"{d['title']} . {d['summary']}"} for d in data], [b1, b2, b3] + + +@pytest.fixture +def TEST_CASE_2() -> Tuple[List[Dict], List[Barrier]]: + corpus = [ + "A man is eating food.", + "A man is eating a piece of bread.", + "The girl is carrying a baby.", + "A man is riding a horse.", + "A woman is playing violin.", + "Two men pushed carts through the woods.", + "A man is riding a white horse on an enclosed ground.", + "A monkey is playing drums.", + "A cheetah is running behind its prey.", + ] + with mute_signals(signals.pre_save): + barriers = [BarrierFactory(title=title) for title in corpus] + + data = ( + Barrier.objects.filter(archived=False) + .exclude(draft=True) + .values("id", "title") + ) + + return [{"id": str(d["id"]), "barrier_corpus": f"{d['title']}"} for d in data], barriers + + +def test_transformer_loads(manager): + assert isinstance(manager.model, SentenceTransformer) + + +def test_case_1(manager, TEST_CASE_1): + training_data, barriers = TEST_CASE_1 + + b1, b2, b3 = barriers + + manager.set_data(training_data) + + assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) + + results1 = manager.get_similar_barriers_searched( + search_term="The weather", + similarity_threshold=0, + quantity=5, + ) + results2 = manager.get_similar_barriers_searched( + search_term="A car at the stadium", + similarity_threshold=0, + quantity=5, + ) + + assert results1 == [ + (str(b3.id), torch.tensor(0.02711891382932663)), + (str(b2.id), torch.tensor(0.45128241181373596)), + (str(b1.id), torch.tensor(0.6327418088912964)), + ] + assert results2 == [ + (str(b1.id), torch.tensor(0.1231116801500320)), + (str(b2.id), torch.tensor(0.223122775554656986)), + (str(b3.id), torch.tensor(0.5845098495483398)), + ] + + +def test_case_2(manager, TEST_CASE_2): + training_data, barriers = TEST_CASE_2 + + manager.set_data(training_data) + + assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) + + results1 = manager.get_similar_barriers_searched( + search_term="A man is eating pasta.", + similarity_threshold=0, + quantity=20, + ) + + assert results1 == [ + (str(barriers[4].id), torch.tensor(0.03359394893050194)), + (str(barriers[7].id), torch.tensor(0.08189050108194351)), + (str(barriers[8].id), torch.tensor(0.09803038835525513)), + (str(barriers[6].id), torch.tensor(0.10469925403594971)), + (str(barriers[3].id), torch.tensor(0.18889553844928741)), + (str(barriers[1].id), torch.tensor(0.5271985530853271)), + (str(barriers[0].id), torch.tensor(0.7035487294197083)) + ] From 4b420216986a19f7625786dc90732aa4400661c7 Mon Sep 17 00:00:00 2001 From: abarolo Date: Wed, 25 Sep 2024 13:56:22 +0100 Subject: [PATCH 02/17] format --- tests/related_barriers/conftest.py | 2 +- .../test_baseline_performance.py | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/related_barriers/conftest.py b/tests/related_barriers/conftest.py index 0d8694bbb..19e17dd35 100644 --- a/tests/related_barriers/conftest.py +++ b/tests/related_barriers/conftest.py @@ -5,7 +5,7 @@ from django.db.models.functions import Concat from api.barriers.models import Barrier -from api.related_barriers.manager import RelatedBarrierManager, get_transformer +from api.related_barriers.manager import RelatedBarrierManager from tests.barriers.factories import BarrierFactory diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index a3ed3a71c..32caac1f4 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -1,16 +1,14 @@ -from typing import Tuple, List, Dict +from typing import Dict, List, Tuple import pytest -from django.db.models import Value as V, signals -from django.db.models.functions import Concat -from django.forms import CharField +import torch +from django.db.models import signals from factory.django import mute_signals +from sentence_transformers import SentenceTransformer + from api.barriers.models import Barrier from api.related_barriers.manager import RelatedBarrierManager -from sentence_transformers import SentenceTransformer -import torch from tests.barriers.factories import BarrierFactory -from django.core.cache import cache pytestmark = [pytest.mark.django_db] @@ -34,7 +32,10 @@ def TEST_CASE_1() -> Tuple[List[Dict], List[Barrier]]: .values("id", "title", "summary") ) - return [{"id": str(d["id"]), "barrier_corpus": f"{d['title']} . {d['summary']}"} for d in data], [b1, b2, b3] + return [ + {"id": str(d["id"]), "barrier_corpus": f"{d['title']} . {d['summary']}"} + for d in data + ], [b1, b2, b3] @pytest.fixture @@ -54,12 +55,12 @@ def TEST_CASE_2() -> Tuple[List[Dict], List[Barrier]]: barriers = [BarrierFactory(title=title) for title in corpus] data = ( - Barrier.objects.filter(archived=False) - .exclude(draft=True) - .values("id", "title") + Barrier.objects.filter(archived=False).exclude(draft=True).values("id", "title") ) - return [{"id": str(d["id"]), "barrier_corpus": f"{d['title']}"} for d in data], barriers + return [ + {"id": str(d["id"]), "barrier_corpus": f"{d['title']}"} for d in data + ], barriers def test_transformer_loads(manager): @@ -118,5 +119,5 @@ def test_case_2(manager, TEST_CASE_2): (str(barriers[6].id), torch.tensor(0.10469925403594971)), (str(barriers[3].id), torch.tensor(0.18889553844928741)), (str(barriers[1].id), torch.tensor(0.5271985530853271)), - (str(barriers[0].id), torch.tensor(0.7035487294197083)) + (str(barriers[0].id), torch.tensor(0.7035487294197083)), ] From c4b1d9f57b8513a0ccc7afb1d4e168a20475d587 Mon Sep 17 00:00:00 2001 From: abarolo Date: Wed, 25 Sep 2024 14:31:01 +0100 Subject: [PATCH 03/17] fix test --- tests/related_barriers/test_baseline_performance.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 32caac1f4..438594128 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -88,14 +88,14 @@ def test_case_1(manager, TEST_CASE_1): ) assert results1 == [ - (str(b3.id), torch.tensor(0.02711891382932663)), - (str(b2.id), torch.tensor(0.45128241181373596)), - (str(b1.id), torch.tensor(0.6327418088912964)), + (str(b3.id), torch.tensor(0.19631657004356384)), + (str(b2.id), torch.tensor(0.5965832471847534)), + (str(b1.id), torch.tensor(0.6942967176437378)), ] assert results2 == [ - (str(b1.id), torch.tensor(0.1231116801500320)), - (str(b2.id), torch.tensor(0.223122775554656986)), - (str(b3.id), torch.tensor(0.5845098495483398)), + (str(b1.id), torch.tensor(0.10439890623092651)), + (str(b2.id), torch.tensor(0.13353225588798523)), + (str(b3.id), torch.tensor(0.6433180570602417)), ] From 189480b505e4d58ed4df391ded3175caee67220e Mon Sep 17 00:00:00 2001 From: abarolo Date: Thu, 26 Sep 2024 10:40:45 +0100 Subject: [PATCH 04/17] debug ci --- tests/related_barriers/test_baseline_performance.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 438594128..d99036d08 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -87,6 +87,10 @@ def test_case_1(manager, TEST_CASE_1): quantity=5, ) + from pprint import pprint + pprint(results1) + pprint(results2) + assert results1 == [ (str(b3.id), torch.tensor(0.19631657004356384)), (str(b2.id), torch.tensor(0.5965832471847534)), From 6232be248a5fea5752c0df0efaba201cfbdfde93 Mon Sep 17 00:00:00 2001 From: abarolo Date: Thu, 26 Sep 2024 11:00:14 +0100 Subject: [PATCH 05/17] f --- tests/related_barriers/test_baseline_performance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index d99036d08..04fdd884b 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -88,8 +88,8 @@ def test_case_1(manager, TEST_CASE_1): ) from pprint import pprint - pprint(results1) - pprint(results2) + pprint([(r[0], r[1].item()) for r in results1]) + pprint([(r[0], r[1].item()) for r in results2]) assert results1 == [ (str(b3.id), torch.tensor(0.19631657004356384)), From 66ffaf5d0bb68a8c115b081288c03bd8c476821a Mon Sep 17 00:00:00 2001 From: abarolo Date: Thu, 26 Sep 2024 11:35:47 +0100 Subject: [PATCH 06/17] f --- api/related_barriers/manager.py | 5 ++-- .../test_baseline_performance.py | 30 ++++++++----------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index ef0b36d79..34c939cab 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -2,7 +2,7 @@ import time from functools import wraps from typing import Dict, List, Optional - +import torch import numpy from django.core.cache import cache from sentence_transformers import SentenceTransformer, util @@ -249,7 +249,8 @@ def get_similar_barriers_searched( for k, v in barrier_scores.items() if v > similarity_threshold and k != embedded_index } - return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + return [(r[0], torch.round(r[1], decimals=4)) for r in results] manager: Optional[RelatedBarrierManager] = None diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 04fdd884b..05657f328 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -87,19 +87,15 @@ def test_case_1(manager, TEST_CASE_1): quantity=5, ) - from pprint import pprint - pprint([(r[0], r[1].item()) for r in results1]) - pprint([(r[0], r[1].item()) for r in results2]) - assert results1 == [ - (str(b3.id), torch.tensor(0.19631657004356384)), - (str(b2.id), torch.tensor(0.5965832471847534)), - (str(b1.id), torch.tensor(0.6942967176437378)), + (str(b3.id), torch.tensor(0.1963)), + (str(b2.id), torch.tensor(0.5966)), + (str(b1.id), torch.tensor(0.6943)), ] assert results2 == [ - (str(b1.id), torch.tensor(0.10439890623092651)), - (str(b2.id), torch.tensor(0.13353225588798523)), - (str(b3.id), torch.tensor(0.6433180570602417)), + (str(b1.id), torch.tensor(0.1044)), + (str(b2.id), torch.tensor(0.1335)), + (str(b3.id), torch.tensor(0.6433)), ] @@ -117,11 +113,11 @@ def test_case_2(manager, TEST_CASE_2): ) assert results1 == [ - (str(barriers[4].id), torch.tensor(0.03359394893050194)), - (str(barriers[7].id), torch.tensor(0.08189050108194351)), - (str(barriers[8].id), torch.tensor(0.09803038835525513)), - (str(barriers[6].id), torch.tensor(0.10469925403594971)), - (str(barriers[3].id), torch.tensor(0.18889553844928741)), - (str(barriers[1].id), torch.tensor(0.5271985530853271)), - (str(barriers[0].id), torch.tensor(0.7035487294197083)), + (str(barriers[4].id), torch.tensor(0.0336)), + (str(barriers[7].id), torch.tensor(0.0819)), + (str(barriers[8].id), torch.tensor(0.0980)), + (str(barriers[6].id), torch.tensor(0.1047)), + (str(barriers[3].id), torch.tensor(0.1889)), + (str(barriers[1].id), torch.tensor(0.5272)), + (str(barriers[0].id), torch.tensor(0.7035)), ] From 28398e406af67da4f060bd875dbcb14bd0a92cf2 Mon Sep 17 00:00:00 2001 From: abarolo Date: Thu, 26 Sep 2024 11:48:14 +0100 Subject: [PATCH 07/17] format --- api/related_barriers/manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index 34c939cab..84d7975c1 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -2,8 +2,9 @@ import time from functools import wraps from typing import Dict, List, Optional -import torch + import numpy +import torch from django.core.cache import cache from sentence_transformers import SentenceTransformer, util From 9cd2224d968198ea7973e96a8c9f238b15b1376b Mon Sep 17 00:00:00 2001 From: root Date: Thu, 26 Sep 2024 12:20:42 +0000 Subject: [PATCH 08/17] fix tests --- config/settings/test.py | 2 +- tests/related_barriers/test_baseline_performance.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/config/settings/test.py b/config/settings/test.py index 38e61cd14..cff42b499 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -15,7 +15,7 @@ ] DOCUMENT_BUCKET = "test-bucket" -CACHES = {"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}} +CACHES = {"default": {"BACKEND": "django.core.cache.backends.dummy.DummyCache"}} # Stop WhiteNoise emitting warnings when running tests without running collectstatic first WHITENOISE_AUTOREFRESH = True diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 05657f328..11d107533 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -14,8 +14,9 @@ @pytest.fixture -def manager(): +def manager(settings): # Loads transformer model + settings.CACHES = {"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}} return RelatedBarrierManager() From 86d655d533fab6475ec91e48df9af52bfd269e0e Mon Sep 17 00:00:00 2001 From: abarolo Date: Thu, 26 Sep 2024 13:21:44 +0100 Subject: [PATCH 09/17] format --- tests/related_barriers/test_baseline_performance.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 11d107533..c0ecec74a 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -16,7 +16,9 @@ @pytest.fixture def manager(settings): # Loads transformer model - settings.CACHES = {"default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}} + settings.CACHES = { + "default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"} + } return RelatedBarrierManager() From 1d6af99897045e03f6b807ca671fff56aaa5fd09 Mon Sep 17 00:00:00 2001 From: abarolo Date: Mon, 30 Sep 2024 13:54:47 +0100 Subject: [PATCH 10/17] dev test --- api/barriers/signals/handlers.py | 14 +----- .../management/commands/related_barriers.py | 50 +++++++++++++++++++ api/related_barriers/manager.py | 43 +++++----------- api/related_barriers/tasks.py | 27 ++++++++++ .../test_baseline_performance.py | 42 ++++++++++++++-- 5 files changed, 130 insertions(+), 46 deletions(-) create mode 100644 api/related_barriers/management/commands/related_barriers.py create mode 100644 api/related_barriers/tasks.py diff --git a/api/barriers/signals/handlers.py b/api/barriers/signals/handlers.py index 6ae785415..84bac0721 100644 --- a/api/barriers/signals/handlers.py +++ b/api/barriers/signals/handlers.py @@ -24,6 +24,7 @@ from api.related_barriers import manager from api.related_barriers.constants import BarrierEntry from api.related_barriers.manager import BARRIER_UPDATE_FIELDS +from api.related_barriers.tasks import update_related_barrier logger = logging.getLogger(__name__) @@ -228,15 +229,4 @@ def related_barrier_update_embeddings(sender, instance, *args, **kwargs): ) if changed and not current_barrier_object.draft: - if not manager.manager: - manager.init() - try: - manager.manager.update_barrier( - BarrierEntry( - id=str(current_barrier_object.id), - barrier_corpus=manager.barrier_to_corpus(instance), - ) - ) - except Exception as e: - # We don't want barrier embedding updates to break worker so just log error - logger.critical(str(e)) + update_related_barrier.delay(barrier_id=str(instance.pk)) diff --git a/api/related_barriers/management/commands/related_barriers.py b/api/related_barriers/management/commands/related_barriers.py new file mode 100644 index 000000000..bb3c0ad7f --- /dev/null +++ b/api/related_barriers/management/commands/related_barriers.py @@ -0,0 +1,50 @@ +import time + +from django.core.management import BaseCommand + +from api.related_barriers import manager + + +class Command(BaseCommand): + help = "Attach policy teams to barriers" + + def add_arguments(self, parser): + parser.add_argument( + "--flush", type=bool, help="Flush cache" + ) + parser.add_argument( + "--reindex", type=bool, help="Reindex cache" + ) + parser.add_argument( + "--stats", type=bool, help="Get Related Barrier cache stats" + ) + + def handle(self, *args, **options): + flush = options["flush"] + reindex = options["reindex"] + stats = options["stats"] + + rb_manager = manager.RelatedBarrierManager() + + barrier_count = len(rb_manager.get_barrier_ids()) + + if stats: + print(f"Barrier Count: {barrier_count}") + + if reindex: + print(f'Reindexing {barrier_count} barriers') + data = manager.get_data() + + rb_manager.flush() + + s = time.time() + rb_manager.set_data(data) + end = time.time() + print('Training time: ', end - s) + return + + if flush: + print('Flushing') + rb_manager.flush() + return + diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index 84d7975c1..af53f3cc4 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -165,7 +165,7 @@ def remove_barrier(self, barrier: BarrierEntry, barrier_ids=None) -> None: @timing def update_barrier(self, barrier: BarrierEntry) -> None: logger.info(f"(Related Barriers): update_barrier {barrier.id}") - barrier_ids = manager.get_barrier_ids() + barrier_ids = self.get_barrier_ids() if barrier.id in barrier_ids: self.remove_barrier(barrier, barrier_ids) self.add_barrier(barrier, barrier_ids) @@ -201,7 +201,8 @@ def get_similar_barriers( for k, v in barrier_scores.items() if v > similarity_threshold and k != barrier.id } - return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] + return [(r[0], torch.round(r[1], decimals=4)) for r in results] @timing def get_similar_barriers_searched( @@ -270,35 +271,27 @@ def get_data_2() -> List[Dict]: from api.barriers.models import Barrier data_dictionary = [] - barriers = Barrier.objects.filter(archived=False).exclude(draft=True) + barriers = Barrier.objects.prefetch_related( + 'interactions_documents' + ).filter(archived=False).exclude(draft=True) for barrier in barriers: corpus_object = {} companies_affected_list = "" if barrier.companies: - companies_affected_list = ", ".join( - [company["name"] for company in barrier.companies] - ) + companies_affected_list = ", ".join([company["name"] for company in barrier.companies]) other_organisations_affected_list = "" if barrier.related_organisations: - other_organisations_affected_list = ", ".join( - [company["name"] for company in barrier.related_organisations] - ) + other_organisations_affected_list = ", ".join([company["name"] for company in barrier.related_organisations]) - notes_text_list = ", ".join( - [note.text for note in barrier.interactions_documents.all()] - ) + notes_text_list = ", ".join([note.text for note in barrier.interactions_documents.all()]) - sectors_list = [ - get_sector(str(sector_id))["name"] for sector_id in barrier.sectors - ] + sectors_list = [get_sector(str(sector_id))["name"] for sector_id in barrier.sectors if get_sector(str(sector_id))] sectors_list.append(get_sector(barrier.main_sector)["name"]) sectors_text = ", ".join(sectors_list) - overseas_region_text = get_barriers_overseas_region( - barrier.country, barrier.trading_bloc - ) + overseas_region_text = get_barriers_overseas_region(barrier.country, barrier.trading_bloc) estimated_resolution_date_text = "" if barrier.estimated_resolution_date: @@ -306,19 +299,7 @@ def get_data_2() -> List[Dict]: estimated_resolution_date_text = f"Estimated to be resolved on {date}." corpus_object["id"] = barrier.id - corpus_object["barrier_corpus"] = ( - f"{barrier.title}. " - f"{barrier.summary}. " - f"{sectors_text} " - f"{barrier.country_name}. " - f"{overseas_region_text}. " - f"{companies_affected_list} " - f"{other_organisations_affected_list} " - f"{notes_text_list}. " - f"{barrier.status_summary}. " - f"{estimated_resolution_date_text}. " - f"{barrier.export_description}." - ) + corpus_object["barrier_corpus"] = (f"{barrier.title}. {barrier.summary}. {sectors_text} {barrier.country_name}. {overseas_region_text}. {companies_affected_list} {other_organisations_affected_list} {notes_text_list}. {barrier.status_summary}. {estimated_resolution_date_text}. {barrier.export_description}.") data_dictionary.append(corpus_object) diff --git a/api/related_barriers/tasks.py b/api/related_barriers/tasks.py new file mode 100644 index 000000000..a420a3cd3 --- /dev/null +++ b/api/related_barriers/tasks.py @@ -0,0 +1,27 @@ +import logging + +from celery import shared_task + +from api.barriers.models import Barrier +from api.related_barriers import manager +from api.related_barriers.constants import BarrierEntry + + +logger = logging.getLogger(__name__) + + +@shared_task +def update_related_barrier(barrier_id: str): + logger.info(f'Updating related barrier embeddings for: {barrier_id}') + if not manager.manager: + manager.init() + try: + manager.manager.update_barrier( + BarrierEntry( + id=barrier_id, + barrier_corpus=manager.barrier_to_corpus(Barrier.objects.get(pk=barrier_id)), + ) + ) + except Exception as e: + # We don't want barrier embedding updates to break worker so just log error + logger.critical(str(e)) diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index c0ecec74a..94bf4d885 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -7,7 +7,8 @@ from sentence_transformers import SentenceTransformer from api.barriers.models import Barrier -from api.related_barriers.manager import RelatedBarrierManager +from api.related_barriers.constants import BarrierEntry +from api.related_barriers.manager import RelatedBarrierManager, barrier_to_corpus from tests.barriers.factories import BarrierFactory pytestmark = [pytest.mark.django_db] @@ -66,11 +67,29 @@ def TEST_CASE_2() -> Tuple[List[Dict], List[Barrier]]: ], barriers +@pytest.fixture +def TEST_CASE_3() -> Tuple[List[Dict], List[Barrier]]: + barriers = [ + BarrierFactory( + title="Translating Trademark for Whisky Products in Quebec", + summary="Bill 96, published by the Quebec Government years ago, changed general regulation regarding French language translations. They ask business to translate everything into French, including trademarks which causes particular difficulties. SWA are waiting for a regulation on the implementation of this as there is a lack of clarity where in the draft regulation they stated that trademarks are exempted but if it includes a general term or description of the product, this should be translated." + ), + BarrierFactory( + title="Offshore wind: local content; skills & workforce engagement", + summary="Taiwan requested local content for offshore wind project but Taiwanese firms currently lack capabilities and experience, which combined with weak domestic competition, has in many cases resulted in inflated costs and unnecessary project delays. In 2021, Taiwan Energy Administration allowed some additional flexibility as developers only needed to source 60% of items locally, compared to 100% previously. This removes the absolute market access barrier but still not enough." + ) + ] + + return [ + {"id": str(b.id), "barrier_corpus": f"{barrier_to_corpus(b)}"} for b in barriers + ], barriers + + def test_transformer_loads(manager): assert isinstance(manager.model, SentenceTransformer) -def test_case_1(manager, TEST_CASE_1): +def test_search_case_1(manager, TEST_CASE_1): training_data, barriers = TEST_CASE_1 b1, b2, b3 = barriers @@ -102,7 +121,7 @@ def test_case_1(manager, TEST_CASE_1): ] -def test_case_2(manager, TEST_CASE_2): +def test_search_case_2(manager, TEST_CASE_2): training_data, barriers = TEST_CASE_2 manager.set_data(training_data) @@ -124,3 +143,20 @@ def test_case_2(manager, TEST_CASE_2): (str(barriers[1].id), torch.tensor(0.5272)), (str(barriers[0].id), torch.tensor(0.7035)), ] + + +def test_related_barriers_1(manager, TEST_CASE_3): + training_data, barriers = TEST_CASE_3 + + manager.set_data(training_data) + + assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) + + barrier_entry = BarrierEntry(id=str(barriers[0].id), barrier_corpus=barrier_to_corpus(barriers[0])) + results1 = manager.get_similar_barriers( + barrier_entry, + similarity_threshold=0, + quantity=20, + ) + + assert results1 == [(str(barriers[1].id), torch.tensor(0.1221))] From f9f31018fa2e4035095faae32a85ba383410b4f0 Mon Sep 17 00:00:00 2001 From: abarolo Date: Mon, 30 Sep 2024 21:47:59 +0100 Subject: [PATCH 11/17] daily task to reindex --- api/related_barriers/tasks.py | 8 ++++++++ config/settings/base.py | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/api/related_barriers/tasks.py b/api/related_barriers/tasks.py index a420a3cd3..9a9a81f15 100644 --- a/api/related_barriers/tasks.py +++ b/api/related_barriers/tasks.py @@ -1,6 +1,7 @@ import logging from celery import shared_task +from django.core.management import call_command from api.barriers.models import Barrier from api.related_barriers import manager @@ -25,3 +26,10 @@ def update_related_barrier(barrier_id: str): except Exception as e: # We don't want barrier embedding updates to break worker so just log error logger.critical(str(e)) + + +@shared_task +def reindex_related_barriers(): + """Schedule daily task to reindex""" + call_command("related_barriers", '--reindex', '1') + diff --git a/config/settings/base.py b/config/settings/base.py index 5691ce711..5795d7cad 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -417,6 +417,12 @@ CELERY_BEAT_SCHEDULE = {} if not DEBUG: + # Runs daily at midnight + CELERY_BEAT_SCHEDULE["reindex_related_barriers"] = { + "task": "api.related_barriers.tasks.reindex_related_barriers", + "schedule": crontab(minute=0, hour=0), + } + # Runs daily at 6am CELERY_BEAT_SCHEDULE["send_notification_emails"] = { "task": "api.user.tasks.send_notification_emails", From 5f4f7e4429e7c96c065d5e48c1b5340f4153c967 Mon Sep 17 00:00:00 2001 From: abarolo Date: Tue, 1 Oct 2024 09:18:55 +0100 Subject: [PATCH 12/17] no task --- api/barriers/signals/handlers.py | 4 +--- api/related_barriers/tasks.py | 1 - tests/barriers/signals/test_handlers.py | 4 ++-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/api/barriers/signals/handlers.py b/api/barriers/signals/handlers.py index 84bac0721..18504e1c5 100644 --- a/api/barriers/signals/handlers.py +++ b/api/barriers/signals/handlers.py @@ -21,8 +21,6 @@ send_top_priority_notification, ) from api.metadata.constants import TOP_PRIORITY_BARRIER_STATUS -from api.related_barriers import manager -from api.related_barriers.constants import BarrierEntry from api.related_barriers.manager import BARRIER_UPDATE_FIELDS from api.related_barriers.tasks import update_related_barrier @@ -229,4 +227,4 @@ def related_barrier_update_embeddings(sender, instance, *args, **kwargs): ) if changed and not current_barrier_object.draft: - update_related_barrier.delay(barrier_id=str(instance.pk)) + update_related_barrier(barrier_id=str(instance.pk)) diff --git a/api/related_barriers/tasks.py b/api/related_barriers/tasks.py index 9a9a81f15..a03d37cdd 100644 --- a/api/related_barriers/tasks.py +++ b/api/related_barriers/tasks.py @@ -11,7 +11,6 @@ logger = logging.getLogger(__name__) -@shared_task def update_related_barrier(barrier_id: str): logger.info(f'Updating related barrier embeddings for: {barrier_id}') if not manager.manager: diff --git a/tests/barriers/signals/test_handlers.py b/tests/barriers/signals/test_handlers.py index b72507cee..cbaa8638e 100644 --- a/tests/barriers/signals/test_handlers.py +++ b/tests/barriers/signals/test_handlers.py @@ -9,8 +9,8 @@ def test_related_barrier_handler(): barrier = BarrierFactory() - with mock.patch("api.barriers.signals.handlers.manager") as mock_manager: + with mock.patch("api.barriers.signals.handlers.update_related_barrier") as mock_update_related_barrier: barrier.title = "New Title" barrier.save() - assert mock_manager.manager.update_barrier.call_count == 1 + mock_update_related_barrier.assert_called_once_with(barrier_id=str(barrier.pk)) From 0a35ca8a265de2a5cbd7a85779e33b40b65180b6 Mon Sep 17 00:00:00 2001 From: abarolo Date: Tue, 1 Oct 2024 09:21:43 +0100 Subject: [PATCH 13/17] format --- .../management/commands/related_barriers.py | 15 +++---- api/related_barriers/manager.py | 43 ++++++++++++++----- api/related_barriers/tasks.py | 10 ++--- tests/barriers/signals/test_handlers.py | 4 +- .../test_baseline_performance.py | 26 +++++++++-- 5 files changed, 67 insertions(+), 31 deletions(-) diff --git a/api/related_barriers/management/commands/related_barriers.py b/api/related_barriers/management/commands/related_barriers.py index bb3c0ad7f..7e5a8b0a1 100644 --- a/api/related_barriers/management/commands/related_barriers.py +++ b/api/related_barriers/management/commands/related_barriers.py @@ -9,12 +9,8 @@ class Command(BaseCommand): help = "Attach policy teams to barriers" def add_arguments(self, parser): - parser.add_argument( - "--flush", type=bool, help="Flush cache" - ) - parser.add_argument( - "--reindex", type=bool, help="Reindex cache" - ) + parser.add_argument("--flush", type=bool, help="Flush cache") + parser.add_argument("--reindex", type=bool, help="Reindex cache") parser.add_argument( "--stats", type=bool, help="Get Related Barrier cache stats" ) @@ -32,7 +28,7 @@ def handle(self, *args, **options): print(f"Barrier Count: {barrier_count}") if reindex: - print(f'Reindexing {barrier_count} barriers') + print(f"Reindexing {barrier_count} barriers") data = manager.get_data() rb_manager.flush() @@ -40,11 +36,10 @@ def handle(self, *args, **options): s = time.time() rb_manager.set_data(data) end = time.time() - print('Training time: ', end - s) + print("Training time: ", end - s) return if flush: - print('Flushing') + print("Flushing") rb_manager.flush() return - diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index af53f3cc4..ecffef815 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -271,27 +271,41 @@ def get_data_2() -> List[Dict]: from api.barriers.models import Barrier data_dictionary = [] - barriers = Barrier.objects.prefetch_related( - 'interactions_documents' - ).filter(archived=False).exclude(draft=True) + barriers = ( + Barrier.objects.prefetch_related("interactions_documents") + .filter(archived=False) + .exclude(draft=True) + ) for barrier in barriers: corpus_object = {} companies_affected_list = "" if barrier.companies: - companies_affected_list = ", ".join([company["name"] for company in barrier.companies]) + companies_affected_list = ", ".join( + [company["name"] for company in barrier.companies] + ) other_organisations_affected_list = "" if barrier.related_organisations: - other_organisations_affected_list = ", ".join([company["name"] for company in barrier.related_organisations]) - - notes_text_list = ", ".join([note.text for note in barrier.interactions_documents.all()]) - - sectors_list = [get_sector(str(sector_id))["name"] for sector_id in barrier.sectors if get_sector(str(sector_id))] + other_organisations_affected_list = ", ".join( + [company["name"] for company in barrier.related_organisations] + ) + + notes_text_list = ", ".join( + [note.text for note in barrier.interactions_documents.all()] + ) + + sectors_list = [ + get_sector(str(sector_id))["name"] + for sector_id in barrier.sectors + if get_sector(str(sector_id)) + ] sectors_list.append(get_sector(barrier.main_sector)["name"]) sectors_text = ", ".join(sectors_list) - overseas_region_text = get_barriers_overseas_region(barrier.country, barrier.trading_bloc) + overseas_region_text = get_barriers_overseas_region( + barrier.country, barrier.trading_bloc + ) estimated_resolution_date_text = "" if barrier.estimated_resolution_date: @@ -299,7 +313,14 @@ def get_data_2() -> List[Dict]: estimated_resolution_date_text = f"Estimated to be resolved on {date}." corpus_object["id"] = barrier.id - corpus_object["barrier_corpus"] = (f"{barrier.title}. {barrier.summary}. {sectors_text} {barrier.country_name}. {overseas_region_text}. {companies_affected_list} {other_organisations_affected_list} {notes_text_list}. {barrier.status_summary}. {estimated_resolution_date_text}. {barrier.export_description}.") + corpus_object["barrier_corpus"] = ( + f"{barrier.title}. {barrier.summary}. " + f"{sectors_text} {barrier.country_name}. " + f"{overseas_region_text}. {companies_affected_list} " + f"{other_organisations_affected_list} {notes_text_list}. " + f"{barrier.status_summary}. {estimated_resolution_date_text}. " + f"{barrier.export_description}." + ) data_dictionary.append(corpus_object) diff --git a/api/related_barriers/tasks.py b/api/related_barriers/tasks.py index a03d37cdd..75adb8c0f 100644 --- a/api/related_barriers/tasks.py +++ b/api/related_barriers/tasks.py @@ -7,19 +7,20 @@ from api.related_barriers import manager from api.related_barriers.constants import BarrierEntry - logger = logging.getLogger(__name__) def update_related_barrier(barrier_id: str): - logger.info(f'Updating related barrier embeddings for: {barrier_id}') + logger.info(f"Updating related barrier embeddings for: {barrier_id}") if not manager.manager: manager.init() try: manager.manager.update_barrier( BarrierEntry( id=barrier_id, - barrier_corpus=manager.barrier_to_corpus(Barrier.objects.get(pk=barrier_id)), + barrier_corpus=manager.barrier_to_corpus( + Barrier.objects.get(pk=barrier_id) + ), ) ) except Exception as e: @@ -30,5 +31,4 @@ def update_related_barrier(barrier_id: str): @shared_task def reindex_related_barriers(): """Schedule daily task to reindex""" - call_command("related_barriers", '--reindex', '1') - + call_command("related_barriers", "--reindex", "1") diff --git a/tests/barriers/signals/test_handlers.py b/tests/barriers/signals/test_handlers.py index cbaa8638e..b4b8d392c 100644 --- a/tests/barriers/signals/test_handlers.py +++ b/tests/barriers/signals/test_handlers.py @@ -9,7 +9,9 @@ def test_related_barrier_handler(): barrier = BarrierFactory() - with mock.patch("api.barriers.signals.handlers.update_related_barrier") as mock_update_related_barrier: + with mock.patch( + "api.barriers.signals.handlers.update_related_barrier" + ) as mock_update_related_barrier: barrier.title = "New Title" barrier.save() diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 94bf4d885..4909eb637 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -72,12 +72,28 @@ def TEST_CASE_3() -> Tuple[List[Dict], List[Barrier]]: barriers = [ BarrierFactory( title="Translating Trademark for Whisky Products in Quebec", - summary="Bill 96, published by the Quebec Government years ago, changed general regulation regarding French language translations. They ask business to translate everything into French, including trademarks which causes particular difficulties. SWA are waiting for a regulation on the implementation of this as there is a lack of clarity where in the draft regulation they stated that trademarks are exempted but if it includes a general term or description of the product, this should be translated." + summary=( + "Bill 96, published by the Quebec Government years ago, changed " + "general regulation regarding French language translations. They " + "ask business to translate everything into French, including trademarks " + "which causes particular difficulties. SWA are waiting for a regulation " + "on the implementation of this as there is a lack of clarity where in " + "the draft regulation they stated that trademarks are exempted but if it " + "includes a general term or description of the product, this should be translated.", + ), ), BarrierFactory( title="Offshore wind: local content; skills & workforce engagement", - summary="Taiwan requested local content for offshore wind project but Taiwanese firms currently lack capabilities and experience, which combined with weak domestic competition, has in many cases resulted in inflated costs and unnecessary project delays. In 2021, Taiwan Energy Administration allowed some additional flexibility as developers only needed to source 60% of items locally, compared to 100% previously. This removes the absolute market access barrier but still not enough." - ) + summary=( + "Taiwan requested local content for offshore wind project but Taiwanese " + "firms currently lack capabilities and experience, which combined with " + "weak domestic competition, has in many cases resulted in inflated costs " + "and unnecessary project delays. In 2021, Taiwan Energy Administration " + "allowed some additional flexibility as developers only needed to source " + "60% of items locally, compared to 100% previously. This removes the " + "absolute market access barrier but still not enough.", + ), + ), ] return [ @@ -152,7 +168,9 @@ def test_related_barriers_1(manager, TEST_CASE_3): assert set(manager.get_barrier_ids()) == set([str(b.id) for b in barriers]) - barrier_entry = BarrierEntry(id=str(barriers[0].id), barrier_corpus=barrier_to_corpus(barriers[0])) + barrier_entry = BarrierEntry( + id=str(barriers[0].id), barrier_corpus=barrier_to_corpus(barriers[0]) + ) results1 = manager.get_similar_barriers( barrier_entry, similarity_threshold=0, From 7d05600b9b9d211c997390cc942580614c5c73d4 Mon Sep 17 00:00:00 2001 From: abarolo Date: Tue, 1 Oct 2024 09:33:21 +0100 Subject: [PATCH 14/17] update --- api/related_barriers/manager.py | 103 ++++++++++++++------------------ 1 file changed, 46 insertions(+), 57 deletions(-) diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index ecffef815..c04b8d0d7 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -224,7 +224,7 @@ def get_similar_barriers_searched( logger.info("(Related Barriers): get_similar_barriers_searched") if not (barrier_ids := self.get_barrier_ids()): - self.set_data(get_data_2()) + self.set_data(get_data()) barrier_ids = self.get_barrier_ids() or [] if not barrier_ids: @@ -261,15 +261,6 @@ def get_similar_barriers_searched( def get_data() -> List[Dict]: from api.barriers.models import Barrier - return [ - {"id": barrier.id, "barrier_corpus": f"{barrier.title} . {barrier.summary}"} - for barrier in Barrier.objects.filter(archived=False).exclude(draft=True) - ] - - -def get_data_2() -> List[Dict]: - from api.barriers.models import Barrier - data_dictionary = [] barriers = ( Barrier.objects.prefetch_related("interactions_documents") @@ -277,52 +268,10 @@ def get_data_2() -> List[Dict]: .exclude(draft=True) ) for barrier in barriers: - corpus_object = {} - - companies_affected_list = "" - if barrier.companies: - companies_affected_list = ", ".join( - [company["name"] for company in barrier.companies] - ) - - other_organisations_affected_list = "" - if barrier.related_organisations: - other_organisations_affected_list = ", ".join( - [company["name"] for company in barrier.related_organisations] - ) - - notes_text_list = ", ".join( - [note.text for note in barrier.interactions_documents.all()] - ) - - sectors_list = [ - get_sector(str(sector_id))["name"] - for sector_id in barrier.sectors - if get_sector(str(sector_id)) - ] - sectors_list.append(get_sector(barrier.main_sector)["name"]) - sectors_text = ", ".join(sectors_list) - - overseas_region_text = get_barriers_overseas_region( - barrier.country, barrier.trading_bloc - ) - - estimated_resolution_date_text = "" - if barrier.estimated_resolution_date: - date = barrier.estimated_resolution_date.strftime("%d-%m-%Y") - estimated_resolution_date_text = f"Estimated to be resolved on {date}." - - corpus_object["id"] = barrier.id - corpus_object["barrier_corpus"] = ( - f"{barrier.title}. {barrier.summary}. " - f"{sectors_text} {barrier.country_name}. " - f"{overseas_region_text}. {companies_affected_list} " - f"{other_organisations_affected_list} {notes_text_list}. " - f"{barrier.status_summary}. {estimated_resolution_date_text}. " - f"{barrier.export_description}." - ) - - data_dictionary.append(corpus_object) + data_dictionary.append({ + "id": str(barrier.id), + "barrier_corpus": barrier_to_corpus(barrier) + }) return data_dictionary @@ -342,4 +291,44 @@ def init(): def barrier_to_corpus(barrier) -> str: - return barrier.title + ". " + barrier.summary + companies_affected_list = "" + if barrier.companies: + companies_affected_list = ", ".join( + [company["name"] for company in barrier.companies] + ) + + other_organisations_affected_list = "" + if barrier.related_organisations: + other_organisations_affected_list = ", ".join( + [company["name"] for company in barrier.related_organisations] + ) + + notes_text_list = ", ".join( + [note.text for note in barrier.interactions_documents.all()] + ) + + sectors_list = [ + get_sector(str(sector_id))["name"] + for sector_id in barrier.sectors + if get_sector(str(sector_id)) + ] + sectors_list.append(get_sector(barrier.main_sector)["name"]) + sectors_text = ", ".join(sectors_list) + + overseas_region_text = get_barriers_overseas_region( + barrier.country, barrier.trading_bloc + ) + + estimated_resolution_date_text = "" + if barrier.estimated_resolution_date: + date = barrier.estimated_resolution_date.strftime("%d-%m-%Y") + estimated_resolution_date_text = f"Estimated to be resolved on {date}." + + return ( + f"{barrier.title}. {barrier.summary}. " + f"{sectors_text} {barrier.country_name}. " + f"{overseas_region_text}. {companies_affected_list} " + f"{other_organisations_affected_list} {notes_text_list}. " + f"{barrier.status_summary}. {estimated_resolution_date_text}. " + f"{barrier.export_description}." + ) From cfe0728f09e1de83b466d83c5130a32673a513fc Mon Sep 17 00:00:00 2001 From: abarolo Date: Tue, 1 Oct 2024 10:16:26 +0100 Subject: [PATCH 15/17] test + format --- api/related_barriers/manager.py | 7 +++---- tests/related_barriers/test_baseline_performance.py | 5 +++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index c04b8d0d7..f0a9168df 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -268,10 +268,9 @@ def get_data() -> List[Dict]: .exclude(draft=True) ) for barrier in barriers: - data_dictionary.append({ - "id": str(barrier.id), - "barrier_corpus": barrier_to_corpus(barrier) - }) + data_dictionary.append( + {"id": str(barrier.id), "barrier_corpus": barrier_to_corpus(barrier)} + ) return data_dictionary diff --git a/tests/related_barriers/test_baseline_performance.py b/tests/related_barriers/test_baseline_performance.py index 4909eb637..9b956e1cb 100644 --- a/tests/related_barriers/test_baseline_performance.py +++ b/tests/related_barriers/test_baseline_performance.py @@ -97,7 +97,8 @@ def TEST_CASE_3() -> Tuple[List[Dict], List[Barrier]]: ] return [ - {"id": str(b.id), "barrier_corpus": f"{barrier_to_corpus(b)}"} for b in barriers + {"id": str(b.id), "barrier_corpus": f"{b.title} . {b.summary}"} + for b in barriers ], barriers @@ -177,4 +178,4 @@ def test_related_barriers_1(manager, TEST_CASE_3): quantity=20, ) - assert results1 == [(str(barriers[1].id), torch.tensor(0.1221))] + assert results1 == [(str(barriers[1].id), torch.tensor(0.1556))] From ec4cf47013afa5bd1841251f6649ff4d2153f8f5 Mon Sep 17 00:00:00 2001 From: abarolo Date: Tue, 1 Oct 2024 10:28:04 +0100 Subject: [PATCH 16/17] fix test --- api/related_barriers/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index f0a9168df..83821a441 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -202,7 +202,7 @@ def get_similar_barriers( if v > similarity_threshold and k != barrier.id } results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:] - return [(r[0], torch.round(r[1], decimals=4)) for r in results] + return [(r[0], torch.round(torch.tensor(r[1]), decimals=4)) for r in results] @timing def get_similar_barriers_searched( From fb1da67b83d4d241e0173c547f57299d6519a06b Mon Sep 17 00:00:00 2001 From: abarolo Date: Tue, 1 Oct 2024 12:32:08 +0100 Subject: [PATCH 17/17] remove comment --- api/related_barriers/manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py index 83821a441..cb52c57c9 100644 --- a/api/related_barriers/manager.py +++ b/api/related_barriers/manager.py @@ -31,7 +31,6 @@ def wrapper(*args, **kwargs): @timing def get_transformer(): return SentenceTransformer("all-MiniLM-L6-v2") - # return SentenceTransformer("paraphrase-MiniLM-L3-v2") BARRIER_UPDATE_FIELDS: List[str] = ["title", "summary"]