Skip to content

Commit

Permalink
Rebase
Browse files (browse the repository at this point in the history)
  • Loading branch information
abarolo authored and abarolo committed Aug 7, 2024
1 parent 39e1788 commit 754792f
Show file tree
Hide file tree
Showing 5 changed files with 794 additions and 601 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
from django.core.management import BaseCommand
from api.barriers.models import Barrier
from api.barriers.tasks import get_barriers_overseas_region
from api.interactions.models import Interaction
from api.metadata.utils import get_sector
from api.related_barriers.views import related_barriers, related_barriers_search
Expand All @@ -9,7 +10,7 @@

class Command(BaseCommand):

# ./manage.py run_search_related_test --searchterm 'barrier is estimated to be resolved on 23 June 2024'
# ./manage.py run_search_related_test --searchterm 'Company called BARCLAYS PLC is affected by this barrier'

help = "Test the related barriers engine with a search term"

Expand All @@ -26,7 +27,7 @@ def handle(self, *args, **options):
response_data = related_barriers_search("request_dummy",options["searchterm"],log_score=True)
logger.critical("------------")

for bar_id in response_data[-6:]:
for bar_id in response_data:#[-6:]:
if bar_id != "search_term":
logger.critical("-")
barrier_detail = Barrier.objects.get(pk=bar_id)
Expand All @@ -35,7 +36,7 @@ def handle(self, *args, **options):
logger.critical("SUMMARY = "+ str(barrier_detail.summary))
sectors_list = [get_sector(sector)["name"] for sector in barrier_detail.sectors]
logger.critical("SECTORS = "+ str(get_sector(barrier_detail.main_sector)["name"]) + " " + str(sectors_list))
logger.critical("COUNTRY = "+ str(barrier_detail.country_name))
logger.critical("COUNTRY = "+ str(get_barriers_overseas_region(barrier_detail.country, barrier_detail.trading_bloc)))
logger.critical("COMPANIES = "+ str(barrier_detail.companies))
logger.critical("OTHER COMPANIES = "+ str(barrier_detail.related_organisations))
logger.critical("STATUS SUMMARY = "+ str(barrier_detail.status_summary))
Expand Down
84 changes: 70 additions & 14 deletions api/related_barriers/manager.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.db.models import Value as V
from django.db.models.functions import Concat
from sentence_transformers import SentenceTransformer, util
import torch

from api.barriers.tasks import get_barriers_overseas_region
from api.interactions.models import Interaction
Expand Down Expand Up @@ -231,6 +232,36 @@ def get_similar_barriers_searched(

cosine_sim = util.cos_sim(new_embeddings, new_embeddings)



#embedder = SentenceTransformer("paraphrase-MiniLM-L3-v2")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpo_list = []
for thing in get_data():
corpo_list.append(thing["barrier_corpus"])
logger.critical("_________")
logger.critical(corpo_list)
logger.critical("_________")
corpus_embeddings = embedder.encode(corpo_list, convert_to_tensor=True)
top_k = min(5, 50)
query_embedding = embedder.encode(search_term, convert_to_tensor=True)
similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)[0]
scores_2, indices = torch.topk(similarity_scores, k=top_k)
logger.critical("++++++++++++++++++++++++++++++++++++++")
logger.critical("++++++++++++++++++++++++++++++++++++++")
for score, idx in zip(scores_2, indices):
print(corpo_list[idx], "(Score: {:.4f})".format(score))
logger.critical("++++++++++++++++++++++++++++++++++++++")
logger.critical("++++++++++++++++++++++++++++++++++++++")









index = len(embeddings)

scores = cosine_sim[index]
Expand All @@ -242,9 +273,15 @@ def get_similar_barriers_searched(
if v > similarity_threshold
}

barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]

logger.critical(barrier_scores.items())
logger.critical("-")

barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])#[-quantity:]

if log_score:
logger.critical(barrier_scores)

barrier_ids = [b[0] for b in barrier_scores]
return barrier_ids

Expand Down Expand Up @@ -287,7 +324,7 @@ def get_data() -> List[Dict]:
if barrier.companies:
for company in barrier.companies:
company_affected_corpus = (
"The company called " + company["name"] + " is affected by this barrier."
"Company called " + company["name"] + " is affected by this barrier. "
)
companies_affected_list = companies_affected_list + company_affected_corpus
#companies_affected_list = companies_affected_list + str(company) + " "
Expand All @@ -296,24 +333,38 @@ def get_data() -> List[Dict]:
if barrier.related_organisations:
for company in barrier.related_organisations:
other_organisations_affected_corpus = (
"The company called " + company["name"] + " is affected by this barrier."
"Company called " + company["name"] + " is affected by this barrier. "
)
other_organisations_affected_list = other_organisations_affected_list + other_organisations_affected_corpus
#other_organisations_affected_list = other_organisations_affected_list + str(company) + " "

# Notes Tests -
#
# "Barriers that need Santino Molinaros attention"
notes_text_list = ""
for note in barrier.interactions_documents.all():
notes_text_list = notes_text_list + note.text + " "


# Sectors Tests -
#
# "Barriers that are affecting the maritime sector"
sectors_list = [get_sector(str(sector_id))["name"] for sector_id in barrier.sectors]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ""
if len(sectors_list) > 0:
for sector_name in sectors_list:
sectors_text = sectors_text + f" This barrier affects the {sector_name} sector. This barrier is related to the {sector_name} sector."
sectors_text = sectors_text + f" Affects the {sector_name} sector. Related to the {sector_name} sector."

# Overseas Region Tests -
#
# "Barriers owned by the South Asia region team"
#
# "Barriers affecting South Asia region"

overseas_region_name = get_barriers_overseas_region(barrier.country, barrier.trading_bloc)
overseas_region_text = f"This barrier belongs to the {overseas_region_name} regional team. This barrier affects the {overseas_region_name} region."
overseas_region_text = f"Belongs to the {overseas_region_name} regional team. Affects the {overseas_region_name} region. Owned by the {overseas_region_name} regional team"
#overseas_region_text = f"{overseas_region_name}"

# ERD tests -
#
Expand All @@ -339,20 +390,21 @@ def get_data() -> List[Dict]:
"%M %d %y",
"%M %d %Y"
]
for date_format in date_formats:
formatted_date = barrier.estimated_resolution_date.strftime(date_format)
#if formatted_date == "23 June 2024":
# logger.critical("======================")
# logger.critical(barrier.id)
# logger.critical("======================")
estimated_resolution_date_text = estimated_resolution_date_text + f"This barrier is estimated to be resolved on {formatted_date}. "
#estimated_resolution_date_text = str(barrier.estimated_resolution_date)
if barrier.estimated_resolution_date:
for date_format in date_formats:
formatted_date = barrier.estimated_resolution_date.strftime(date_format)
#if formatted_date == "23 June 2024":
# logger.critical("======================")
# logger.critical(barrier.id)
# logger.critical("======================")
estimated_resolution_date_text = estimated_resolution_date_text + f"Estimated to be resolved on {formatted_date}. "
#estimated_resolution_date_text = str(barrier.estimated_resolution_date)

corpus_object["id"] = barrier.id
corpus_object["barrier_corpus"] = (
f"{barrier.title}. "
f"{barrier.summary}. "
f"{sectors_text}. "
f"{sectors_text} "
f"{barrier.country_name}. "
f"{overseas_region_text}. "
f"{companies_affected_list} "
Expand All @@ -363,6 +415,10 @@ def get_data() -> List[Dict]:
f"{barrier.export_description}."
)

logger.critical("-")
logger.critical(corpus_object["barrier_corpus"])
logger.critical("-")

data_dictionary.append(corpus_object)

return data_dictionary
Expand Down
Loading

0 comments on commit 754792f

Please sign in to comment.