Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
abarolo authored and abarolo committed Aug 7, 2024
1 parent c090418 commit 9e387d6
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 34 deletions.
60 changes: 45 additions & 15 deletions api/related_barriers/management/commands/run_search_related_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import logging

from django.core.management import BaseCommand

from api.barriers.models import Barrier
from api.barriers.tasks import get_barriers_overseas_region
from api.interactions.models import Interaction
from api.metadata.utils import get_sector
from api.related_barriers.views import related_barriers, related_barriers_search

import logging
logger = logging.getLogger(__name__)


class Command(BaseCommand):

# ./manage.py run_search_related_test --searchterm 'Company called BARCLAYS PLC is affected by this barrier'
Expand All @@ -24,26 +27,53 @@ def handle(self, *args, **options):
logger.critical("+++++++++")

logger.critical("SCORES -----")
response_data = related_barriers_search("request_dummy",options["searchterm"],log_score=True)
response_data = related_barriers_search(
"request_dummy", options["searchterm"], log_score=True
)
logger.critical("------------")

for bar_id in response_data:#[-6:]:
for bar_id in response_data: # [-6:]:
if bar_id != "search_term":
logger.critical("-")
barrier_detail = Barrier.objects.get(pk=bar_id)
logger.critical("ID = " + str(bar_id))
logger.critical("TITLE = "+ str(barrier_detail.title))
logger.critical("SUMMARY = "+ str(barrier_detail.summary))
sectors_list = [get_sector(sector)["name"] for sector in barrier_detail.sectors]
logger.critical("SECTORS = "+ str(get_sector(barrier_detail.main_sector)["name"]) + " " + str(sectors_list))
logger.critical("COUNTRY = "+ str(get_barriers_overseas_region(barrier_detail.country, barrier_detail.trading_bloc)))
logger.critical("COMPANIES = "+ str(barrier_detail.companies))
logger.critical("OTHER COMPANIES = "+ str(barrier_detail.related_organisations))
logger.critical("STATUS SUMMARY = "+ str(barrier_detail.status_summary))
logger.critical("EST. RESOLUTION DATE = "+ str(barrier_detail.estimated_resolution_date))
logger.critical("EXPORT DESC. = "+ str(barrier_detail.export_description))
notes_list = [note.text for note in barrier_detail.interactions_documents.all()]
logger.critical("NOTES = "+ str(notes_list))
logger.critical("TITLE = " + str(barrier_detail.title))
logger.critical("SUMMARY = " + str(barrier_detail.summary))
sectors_list = [
get_sector(sector)["name"] for sector in barrier_detail.sectors
]
logger.critical(
"SECTORS = "
+ str(get_sector(barrier_detail.main_sector)["name"])
+ " "
+ str(sectors_list)
)
logger.critical(
"COUNTRY = "
+ str(
get_barriers_overseas_region(
barrier_detail.country, barrier_detail.trading_bloc
)
)
)
logger.critical("COMPANIES = " + str(barrier_detail.companies))
logger.critical(
"OTHER COMPANIES = " + str(barrier_detail.related_organisations)
)
logger.critical(
"STATUS SUMMARY = " + str(barrier_detail.status_summary)
)
logger.critical(
"EST. RESOLUTION DATE = "
+ str(barrier_detail.estimated_resolution_date)
)
logger.critical(
"EXPORT DESC. = " + str(barrier_detail.export_description)
)
notes_list = [
note.text for note in barrier_detail.interactions_documents.all()
]
logger.critical("NOTES = " + str(notes_list))
logger.critical("-")

logger.critical("+++++++++")
Expand Down
45 changes: 30 additions & 15 deletions api/related_barriers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@
from typing import Dict, List, Optional

import numpy
import torch
from django.core.cache import cache
from django.db.models import CharField
from django.db.models import Value as V
from django.db.models.functions import Concat
from sentence_transformers import SentenceTransformer, util
import torch

from api.barriers.tasks import get_barriers_overseas_region
from api.interactions.models import Interaction
from api.metadata.utils import get_sector
from api.related_barriers.constants import BarrierEntry

from api.metadata.utils import (get_sector)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -78,7 +77,6 @@ def set_data(self, data: List[Dict]):
"""
Load data into memory.
"""
print('Setting data')
logger.info("(Related Barriers): set_data")
self.flush()
barrier_ids = [str(d["id"]) for d in data]
Expand Down Expand Up @@ -208,25 +206,32 @@ def get_similar_barriers(

barrier_ids = [b[0] for b in barrier_scores]
return barrier_ids


@timing
def get_similar_barriers_searched(
self, search_term: str, similarity_threshold: float, quantity: int=None, log_score=None
self,
search_term: str,
similarity_threshold: float,
quantity: int = None,
log_score=None,
):

logger.info(f"(Related Barriers): get_similar_barriers_seaarched")
barrier_ids = self.get_barrier_ids()

if not barrier_ids:
self.set_data(get_data())
self.set_data(get_data())

barrier_ids = self.get_barrier_ids()

embedded_index = "search_term"
search_term_embedding = self.model.encode(search_term, convert_to_tensor=True).numpy()
search_term_embedding = self.model.encode(
search_term, convert_to_tensor=True
).numpy()
embeddings = self.get_embeddings()
new_embeddings = numpy.vstack([embeddings, search_term_embedding]) # append embedding
new_embeddings = numpy.vstack(
[embeddings, search_term_embedding]
) # append embedding
new_barrier_ids = barrier_ids + [embedded_index] # append barrier_id

cosine_sim = util.cos_sim(new_embeddings, new_embeddings)
Expand Down Expand Up @@ -265,21 +270,31 @@ def get_data_2() -> List[Dict]:

companies_affected_list = ""
if barrier.companies:
companies_affected_list = ", ".join([company['name'] for company in barrier.companies])
companies_affected_list = ", ".join(
[company["name"] for company in barrier.companies]
)

other_organisations_affected_list = ""
if barrier.related_organisations:
other_organisations_affected_list = ", ".join([company['name'] for company in barrier.related_organisations])
other_organisations_affected_list = ", ".join(
[company["name"] for company in barrier.related_organisations]
)

notes_text_list = ", ".join([note.text for note in barrier.interactions_documents.all()])
notes_text_list = ", ".join(
[note.text for note in barrier.interactions_documents.all()]
)

sectors_list = [get_sector(str(sector_id))["name"] for sector_id in barrier.sectors]
sectors_list = [
get_sector(str(sector_id))["name"] for sector_id in barrier.sectors
]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ", ".join(sectors_list)

overseas_region_text = get_barriers_overseas_region(barrier.country, barrier.trading_bloc)
overseas_region_text = get_barriers_overseas_region(
barrier.country, barrier.trading_bloc
)

estimated_resolution_date_text = ''
estimated_resolution_date_text = ""
if barrier.estimated_resolution_date:
date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
estimated_resolution_date_text = f"Estimated to be resolved on {date}."
Expand Down
4 changes: 1 addition & 3 deletions api/related_barriers/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,5 @@
path(
"barriers/<uuid:pk>/related-barriers", related_barriers, name="related-barriers"
),
path(
"related-barriers", related_barriers_search, name="related-barriers-search"
),
path("related-barriers", related_barriers_search, name="related-barriers-search"),
]
2 changes: 1 addition & 1 deletion api/related_barriers/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def related_barriers_search(request) -> Response:
manager.init()

barrier_scores = manager.manager.get_similar_barriers_searched(
search_term=serializer.data['search_term'],
search_term=serializer.data["search_term"],
similarity_threshold=SIMILARITY_THRESHOLD,
quantity=SIMILAR_BARRIERS_LIMIT,
)
Expand Down

0 comments on commit 9e387d6

Please sign in to comment.