From c09041828b3023bb19b63c96dd7a07f72aa1494d Mon Sep 17 00:00:00 2001
From: abarolo
Date: Wed, 7 Aug 2024 08:55:25 +0100
Subject: [PATCH] Update and working search endpoint

---
 api/related_barriers/manager.py     | 182 +++++-----------------------
 api/related_barriers/serializers.py |   4 +
 api/related_barriers/urls.py        |   2 +-
 api/related_barriers/views.py       |  19 +--
 4 files changed, 48 insertions(+), 159 deletions(-)

diff --git a/api/related_barriers/manager.py b/api/related_barriers/manager.py
index 4bb4b655a..c44570b23 100644
--- a/api/related_barriers/manager.py
+++ b/api/related_barriers/manager.py
@@ -35,6 +35,7 @@ def wrapper(*args, **kwargs):
 
 @timing
 def get_transformer():
+    # SentenceTransformer("all-MiniLM-L6-v2")
     return SentenceTransformer("paraphrase-MiniLM-L3-v2")
 
 
@@ -77,6 +78,7 @@ def set_data(self, data: List[Dict]):
         """
         Load data into memory.
         """
+        print('Setting data')
         logger.info("(Related Barriers): set_data")
         self.flush()
         barrier_ids = [str(d["id"]) for d in data]
@@ -210,17 +212,14 @@ def get_similar_barriers(
 
     @timing
     def get_similar_barriers_searched(
-        self, search_term: str, similarity_threshold: float, quantity: int, log_score: bool,
+        self, search_term: str, similarity_threshold: float, quantity: int=None, log_score=None
     ):
         logger.info(f"(Related Barriers): get_similar_barriers_seaarched")
 
         barrier_ids = self.get_barrier_ids()
 
-        # RESTORE THESE 2 LINES BEFORE COMMITTING - CACHING NEEDS TO GO FOR TESTING PURPOSESES
-        #if not barrier_ids:
-        #    self.set_data(get_data())
-        # DELETE THIS LINE BEFORE COMMITTING - CACHING SHOULD COME BACK FOR REAL
-        self.set_data(get_data())
+        if not barrier_ids:
+            self.set_data(get_data())
 
         barrier_ids = self.get_barrier_ids()
 
@@ -232,58 +231,16 @@ def get_similar_barriers_searched(
         cosine_sim = util.cos_sim(new_embeddings, new_embeddings)
 
-
-
-        #embedder = SentenceTransformer("paraphrase-MiniLM-L3-v2")
-        embedder = SentenceTransformer("all-MiniLM-L6-v2")
-        corpo_list = []
-        for thing in get_data():
-            corpo_list.append(thing["barrier_corpus"])
-        logger.critical("_________")
-        logger.critical(corpo_list)
-        logger.critical("_________")
-        corpus_embeddings = embedder.encode(corpo_list, convert_to_tensor=True)
-        top_k = min(5, 50)
-        query_embedding = embedder.encode(search_term, convert_to_tensor=True)
-        similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)[0]
-        scores_2, indices = torch.topk(similarity_scores, k=top_k)
-        logger.critical("++++++++++++++++++++++++++++++++++++++")
-        logger.critical("++++++++++++++++++++++++++++++++++++++")
-        for score, idx in zip(scores_2, indices):
-            print(corpo_list[idx], "(Score: {:.4f})".format(score))
-        logger.critical("++++++++++++++++++++++++++++++++++++++")
-        logger.critical("++++++++++++++++++++++++++++++++++++++")
-
-
-
-
-
-
-
-
-
-        index = len(embeddings)
-
-        scores = cosine_sim[index]
+        index_search_term_embedding = len(new_embeddings) - 1
+        scores = cosine_sim[index_search_term_embedding]
 
         barrier_scores = dict(zip(new_barrier_ids, scores))
-
         barrier_scores = {
             k: v
             for k, v in barrier_scores.items()
-            if v > similarity_threshold
+            if v > similarity_threshold and k != embedded_index
         }
-
-
-        logger.critical(barrier_scores.items())
-        logger.critical("-")
-
-        barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])#[-quantity:]
-
-        if log_score:
-            logger.critical(barrier_scores)
-
-        barrier_ids = [b[0] for b in barrier_scores]
-        return barrier_ids
+        barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
+        return barrier_scores
 
 
 manager: Optional[RelatedBarrierManager] = None
@@ -292,113 +249,40 @@ def get_similar_barriers_searched(
 
 def get_data() -> List[Dict]:
     from api.barriers.models import Barrier
 
-    # ./manage.py run_test
+    return [
+        {"id": barrier.id, "barrier_corpus": f"{barrier.title} . {barrier.summary}"}
+        for barrier in Barrier.objects.filter(archived=False).exclude(draft=True)
+    ]
+
+
+def get_data_2() -> List[Dict]:
+    from api.barriers.models import Barrier
 
     data_dictionary = []
     barriers = Barrier.objects.filter(archived=False).exclude(draft=True)
     for barrier in barriers:
         corpus_object = {}
 
-        # Companies tests
-
-        #
-        # "Barriers that affect barclays bank"
-        # WITH Additional sentence text
-        #('dde42f8f-b02b-41a6-b794-3cb6e1cdade5', tensor(0.3390))
-        # WITHOUT Additional sentence text
-        #('dde42f8f-b02b-41a6-b794-3cb6e1cdade5', tensor(0.2761))
-        # (Barrier with Barclays PLC not the highest match)
-        #
-        # "Barclays bank"
-        # No results - one word match results in too low a threshold for result
-        #
-        # "show me barriers that are related to Barclays bank"
-        # WITH Additional sentence text
-        # ('dde42f8f-b02b-41a6-b794-3cb6e1cdade5', tensor(0.4561))
-        # barrier with Barlays under companies comes 2nd in matching - the highest
-        # has a export description of "Deal account account tough" - which could be words
-        # related to the query word "bank"
-        # WITHOUT Additional sentence text
-        # Expected barrier did not appear
-
         companies_affected_list = ""
         if barrier.companies:
-            for company in barrier.companies:
-                company_affected_corpus = (
-                    "Company called " + company["name"] + " is affected by this barrier. "
-                )
-                companies_affected_list = companies_affected_list + company_affected_corpus
-            #companies_affected_list = companies_affected_list + str(company) + " "
+            companies_affected_list = ", ".join([company['name'] for company in barrier.companies])
 
         other_organisations_affected_list = ""
         if barrier.related_organisations:
-            for company in barrier.related_organisations:
-                other_organisations_affected_corpus = (
-                    "Company called " + company["name"] + " is affected by this barrier. "
-                )
-                other_organisations_affected_list = other_organisations_affected_list + other_organisations_affected_corpus
-            #other_organisations_affected_list = other_organisations_affected_list + str(company) + " "
-
-        # Notes Tests
-
-        #
-        # "Barriers that need Santino Molinaros attention"
-        notes_text_list = ""
-        for note in barrier.interactions_documents.all():
-            notes_text_list = notes_text_list + note.text + " "
-
-
-        # Sectors Tests
-
-        #
-        # "Barriers that are affecting the maritime sector"
+            other_organisations_affected_list = ", ".join([company['name'] for company in barrier.related_organisations])
+
+        notes_text_list = ", ".join([note.text for note in barrier.interactions_documents.all()])
+
         sectors_list = [get_sector(str(sector_id))["name"] for sector_id in barrier.sectors]
         sectors_list.append(get_sector(barrier.main_sector)["name"])
-        sectors_text = ""
-        if len(sectors_list) > 0:
-            for sector_name in sectors_list:
-                sectors_text = sectors_text + f" Affects the {sector_name} sector. Related to the {sector_name} sector."
-
-        # Overseas Region Tests
-
-        #
-        # "Barriers owned by the South Asia region team"
-        #
-        # "Barriers affecting South Asia region"
-
-        overseas_region_name = get_barriers_overseas_region(barrier.country, barrier.trading_bloc)
-        overseas_region_text = f"Belongs to the {overseas_region_name} regional team. Affects the {overseas_region_name} region. Owned by the {overseas_region_name} regional team"
-        #overseas_region_text = f"{overseas_region_name}"
-
-        # ERD tests
-
-        #
-        # "Barriers which are estimated to be resolved on 23 June 2024"
-        # Expected result not returned, but did get results back
-        #
-        # "barrier is estimated to be resolved on 23 June 2024"
-        # Expected result not returned, but did get results back
-
-        estimated_resolution_date_text = ""
-        date_formats = [
-            "%d-%m-%Y",
-            "%d/%m/%Y",
-            "%d %m %Y",
-            "%d-%m-%y",
-            "%d/%m/%y",
-            "%d %m %y",
-            "%d %b %Y",
-            "%d %B %Y",
-            "%x",
-            "%M %y",
-            "%M %Y",
-            "%M %d %y",
-            "%M %d %Y"
-        ]
+        sectors_text = ", ".join(sectors_list)
+
+        overseas_region_text = get_barriers_overseas_region(barrier.country, barrier.trading_bloc)
+
+        estimated_resolution_date_text = ''
         if barrier.estimated_resolution_date:
-            for date_format in date_formats:
-                formatted_date = barrier.estimated_resolution_date.strftime(date_format)
-                #if formatted_date == "23 June 2024":
-                #    logger.critical("======================")
-                #    logger.critical(barrier.id)
-                #    logger.critical("======================")
-                estimated_resolution_date_text = estimated_resolution_date_text + f"Estimated to be resolved on {formatted_date}. "
-            #estimated_resolution_date_text = str(barrier.estimated_resolution_date)
+            date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
+            estimated_resolution_date_text = f"Estimated to be resolved on {date}."
 
         corpus_object["id"] = barrier.id
         corpus_object["barrier_corpus"] = (
@@ -415,14 +299,11 @@ def get_data() -> List[Dict]:
             f"{barrier.export_description}."
         )
 
-        logger.critical("-")
-        logger.critical(corpus_object["barrier_corpus"])
-        logger.critical("-")
-
         data_dictionary.append(corpus_object)
 
     return data_dictionary
 
+
 def init():
     global manager
 
@@ -430,6 +311,7 @@ def init():
         raise Exception("Related Barrier Manager already set")
 
     manager = RelatedBarrierManager()
+
     if not manager.get_barrier_ids():
         logger.info("(Related Barriers): Initialising)")
         data = get_data()
diff --git a/api/related_barriers/serializers.py b/api/related_barriers/serializers.py
index 03297a364..7c8221864 100644
--- a/api/related_barriers/serializers.py
+++ b/api/related_barriers/serializers.py
@@ -13,3 +13,7 @@ class BarrierRelatedListSerializer(serializers.Serializer):
     status = StatusField(required=False)
     location = serializers.CharField(read_only=True)
     similarity = serializers.FloatField(read_only=True)
+
+
+class SearchRequest(serializers.Serializer):
+    search_term = serializers.CharField(required=True)
diff --git a/api/related_barriers/urls.py b/api/related_barriers/urls.py
index b84f98a20..13e665f01 100644
--- a/api/related_barriers/urls.py
+++ b/api/related_barriers/urls.py
@@ -13,6 +13,6 @@
         "barriers//related-barriers", related_barriers, name="related-barriers"
     ),
     path(
-        "barriers/dashboard/", related_barriers_search, name="related-barriers-search"
+        "related-barriers", related_barriers_search, name="related-barriers-search"
     ),
 ]
diff --git a/api/related_barriers/views.py b/api/related_barriers/views.py
index cfa839a3a..604b33911 100644
--- a/api/related_barriers/views.py
+++ b/api/related_barriers/views.py
@@ -11,7 +11,7 @@
     SIMILARITY_THRESHOLD,
     BarrierEntry,
 )
-from api.related_barriers.serializers import BarrierRelatedListSerializer
+from api.related_barriers.serializers import BarrierRelatedListSerializer, SearchRequest
 
 logger = logging.getLogger(__name__)
 
@@ -42,26 +42,29 @@ def related_barriers(request, pk) -> Response:
         ).data
     )
 
-#@api_view(["GET"])
-def related_barriers_search(request, search_term, log_score) -> Response:
+
+@api_view(["GET"])
+def related_barriers_search(request) -> Response:
     """
     Return a list of related barriers given a search term/phrase
     """
+    serializer = SearchRequest(data=request.GET)
+    serializer.is_valid(raise_exception=True)
+
     if manager.manager is None:
         manager.init()
 
-    similar_barrier_ids = manager.manager.get_similar_barriers_searched(
-        search_term=search_term,
+    barrier_scores = manager.manager.get_similar_barriers_searched(
+        search_term=serializer.data['search_term'],
         similarity_threshold=SIMILARITY_THRESHOLD,
         quantity=SIMILAR_BARRIERS_LIMIT,
-        log_score=log_score
     )
 
-    return similar_barrier_ids
+    barrier_ids = [b[0] for b in barrier_scores]
 
     return Response(
         BarrierRelatedListSerializer(
-            Barrier.objects.filter(id__in=similar_barrier_ids), many=True
+            Barrier.objects.filter(id__in=barrier_ids), many=True
         ).data
     )
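
Reviewer note (not part of the diff above): for anyone who wants to poke at the new ranking logic outside Django, the snippet below is a minimal, self-contained sketch of the flow that get_similar_barriers_searched now implements: embed the barrier corpus plus the search term with the same SentenceTransformer model, read the cosine-similarity row for the search term, drop weak matches, and keep the highest-scoring barrier ids. The corpus strings, threshold, and quantity here are illustrative stand-ins, not repo values (the view passes SIMILARITY_THRESHOLD and SIMILAR_BARRIERS_LIMIT), and the search-term entry is excluded by slicing rather than by the manager's `k != embedded_index` check.

    # Illustrative sketch only - mirrors get_similar_barriers_searched(), but the
    # corpus, threshold, and quantity below are made-up values, not repo constants.
    from sentence_transformers import SentenceTransformer, util

    barrier_ids = ["id-1", "id-2"]
    corpus = [
        "Barrier title one . Barrier summary one",
        "Barrier title two . Barrier summary two",
    ]
    search_term = "barriers affecting the maritime sector"

    model = SentenceTransformer("paraphrase-MiniLM-L3-v2")

    # Embed the corpus with the search term appended last, then compare everything
    # against everything, as the manager does with util.cos_sim().
    embeddings = model.encode(corpus + [search_term])
    cosine_sim = util.cos_sim(embeddings, embeddings)

    # The last row holds the search term's similarity to each corpus entry
    # (index_search_term_embedding = len(new_embeddings) - 1 in the patch).
    scores = cosine_sim[len(embeddings) - 1]

    similarity_threshold = 0.19  # stand-in for SIMILARITY_THRESHOLD
    quantity = 5                 # stand-in for SIMILAR_BARRIERS_LIMIT

    # Pair ids with scores, drop weak matches (the slice also drops the search
    # term's self-match), then keep the top `quantity` sorted by ascending score,
    # which is the (id, score) shape the manager now returns.
    barrier_scores = dict(zip(barrier_ids, scores[: len(barrier_ids)]))
    barrier_scores = {k: v for k, v in barrier_scores.items() if v > similarity_threshold}
    top_matches = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
    print(top_matches)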
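Once the urls.py change lands, the endpoint takes a plain GET with a search_term query parameter, validated by the new SearchRequest serializer. A hedged example using Django's test client follows; the reverse() call assumes the "related-barriers-search" route name resolves without a namespace, and the full URL prefix depends on where api/related_barriers/urls.py is included in the project.

    # Assumes the "related-barriers-search" route name resolves without a
    # namespace; adjust reverse() (or hard-code the prefixed path) if it does not.
    from django.test import Client
    from django.urls import reverse

    client = Client()
    url = reverse("related-barriers-search")
    response = client.get(url, {"search_term": "barriers affecting the maritime sector"})

    print(response.status_code)  # 200 expected; a missing search_term yields 400 from the serializer
    print(response.json())       # barriers serialized by BarrierRelatedListSerializer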