
Update and working search endpoint
abarolo authored and abarolo committed Aug 7, 2024
1 parent 754792f commit c090418
Showing 4 changed files with 48 additions and 159 deletions.
182 changes: 32 additions & 150 deletions api/related_barriers/manager.py
@@ -35,6 +35,7 @@ def wrapper(*args, **kwargs):

@timing
def get_transformer():
# SentenceTransformer("all-MiniLM-L6-v2")
return SentenceTransformer("paraphrase-MiniLM-L3-v2")


@@ -77,6 +78,7 @@ def set_data(self, data: List[Dict]):
"""
Load data into memory.
"""
print('Setting data')
logger.info("(Related Barriers): set_data")
self.flush()
barrier_ids = [str(d["id"]) for d in data]
@@ -210,17 +212,14 @@ def get_similar_barriers_searched(

@timing
def get_similar_barriers_searched(
self, search_term: str, similarity_threshold: float, quantity: int, log_score: bool,
self, search_term: str, similarity_threshold: float, quantity: int=None, log_score=None
):

logger.info(f"(Related Barriers): get_similar_barriers_seaarched")
barrier_ids = self.get_barrier_ids()

# RESTORE THESE 2 LINES BEFORE COMMITTING - CACHING NEEDS TO GO FOR TESTING PURPOSESES
#if not barrier_ids:
# self.set_data(get_data())
# DELETE THIS LINE BEFORE COMMITTING - CACHING SHOULD COME BACK FOR REAL
self.set_data(get_data())
if not barrier_ids:
self.set_data(get_data())

barrier_ids = self.get_barrier_ids()

@@ -232,58 +231,16 @@ def get_similar_barriers_searched(

cosine_sim = util.cos_sim(new_embeddings, new_embeddings)



#embedder = SentenceTransformer("paraphrase-MiniLM-L3-v2")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpo_list = []
for thing in get_data():
corpo_list.append(thing["barrier_corpus"])
logger.critical("_________")
logger.critical(corpo_list)
logger.critical("_________")
corpus_embeddings = embedder.encode(corpo_list, convert_to_tensor=True)
top_k = min(5, 50)
query_embedding = embedder.encode(search_term, convert_to_tensor=True)
similarity_scores = embedder.similarity(query_embedding, corpus_embeddings)[0]
scores_2, indices = torch.topk(similarity_scores, k=top_k)
logger.critical("++++++++++++++++++++++++++++++++++++++")
logger.critical("++++++++++++++++++++++++++++++++++++++")
for score, idx in zip(scores_2, indices):
print(corpo_list[idx], "(Score: {:.4f})".format(score))
logger.critical("++++++++++++++++++++++++++++++++++++++")
logger.critical("++++++++++++++++++++++++++++++++++++++")









index = len(embeddings)

scores = cosine_sim[index]
index_search_term_embedding = len(new_embeddings) - 1
scores = cosine_sim[index_search_term_embedding]
barrier_scores = dict(zip(new_barrier_ids, scores))

barrier_scores = {
k: v
for k, v in barrier_scores.items()
if v > similarity_threshold
if v > similarity_threshold and k != embedded_index
}


logger.critical(barrier_scores.items())
logger.critical("-")

barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])#[-quantity:]

if log_score:
logger.critical(barrier_scores)

barrier_ids = [b[0] for b in barrier_scores]
return barrier_ids
barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
return barrier_scores


manager: Optional[RelatedBarrierManager] = None
@@ -292,113 +249,40 @@ def get_similar_barriers_searched(
def get_data() -> List[Dict]:
from api.barriers.models import Barrier

# ./manage.py run_test
return [
{"id": barrier.id, "barrier_corpus": f"{barrier.title} . {barrier.summary}"}
for barrier in Barrier.objects.filter(archived=False).exclude(draft=True)
]


def get_data_2() -> List[Dict]:
from api.barriers.models import Barrier

data_dictionary = []
barriers = Barrier.objects.filter(archived=False).exclude(draft=True)
for barrier in barriers:
corpus_object = {}

# Companies tests -
#
# "Barriers that affect barclays bank"
# WITH Additional sentence text
#('dde42f8f-b02b-41a6-b794-3cb6e1cdade5', tensor(0.3390))
# WITHOUT Additional sentence text
#('dde42f8f-b02b-41a6-b794-3cb6e1cdade5', tensor(0.2761))
# (Barrier with Barclays PLC not the highest match)
#
# "Barclays bank"
# No results - one word match results in too low a threshold for result
#
# "show me barriers that are related to Barclays bank"
# WITH Additional sentence text
# ('dde42f8f-b02b-41a6-b794-3cb6e1cdade5', tensor(0.4561))
# barrier with Barlays under companies comes 2nd in matching - the highest
# has a export description of "Deal account account tough" - which could be words
# related to the query word "bank"
# WITHOUT Additional sentence text
# Expected barrier did not appear

companies_affected_list = ""
if barrier.companies:
for company in barrier.companies:
company_affected_corpus = (
"Company called " + company["name"] + " is affected by this barrier. "
)
companies_affected_list = companies_affected_list + company_affected_corpus
#companies_affected_list = companies_affected_list + str(company) + " "
companies_affected_list = ", ".join([company['name'] for company in barrier.companies])

other_organisations_affected_list = ""
if barrier.related_organisations:
for company in barrier.related_organisations:
other_organisations_affected_corpus = (
"Company called " + company["name"] + " is affected by this barrier. "
)
other_organisations_affected_list = other_organisations_affected_list + other_organisations_affected_corpus
#other_organisations_affected_list = other_organisations_affected_list + str(company) + " "

# Notes Tests -
#
# "Barriers that need Santino Molinaros attention"
notes_text_list = ""
for note in barrier.interactions_documents.all():
notes_text_list = notes_text_list + note.text + " "


# Sectors Tests -
#
# "Barriers that are affecting the maritime sector"
other_organisations_affected_list = ", ".join([company['name'] for company in barrier.related_organisations])

notes_text_list = ", ".join([note.text for note in barrier.interactions_documents.all()])

sectors_list = [get_sector(str(sector_id))["name"] for sector_id in barrier.sectors]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ""
if len(sectors_list) > 0:
for sector_name in sectors_list:
sectors_text = sectors_text + f" Affects the {sector_name} sector. Related to the {sector_name} sector."

# Overseas Region Tests -
#
# "Barriers owned by the South Asia region team"
#
# "Barriers affecting South Asia region"

overseas_region_name = get_barriers_overseas_region(barrier.country, barrier.trading_bloc)
overseas_region_text = f"Belongs to the {overseas_region_name} regional team. Affects the {overseas_region_name} region. Owned by the {overseas_region_name} regional team"
#overseas_region_text = f"{overseas_region_name}"

# ERD tests -
#
# "Barriers which are estimated to be resolved on 23 June 2024"
# Expected result not returned, but did get results back
#
# "barrier is estimated to be resolved on 23 June 2024"
# Expected result not returned, but did get results back

estimated_resolution_date_text = ""
date_formats = [
"%d-%m-%Y",
"%d/%m/%Y",
"%d %m %Y",
"%d-%m-%y",
"%d/%m/%y",
"%d %m %y",
"%d %b %Y",
"%d %B %Y",
"%x",
"%M %y",
"%M %Y",
"%M %d %y",
"%M %d %Y"
]
sectors_text = ", ".join(sectors_list)

overseas_region_text = get_barriers_overseas_region(barrier.country, barrier.trading_bloc)

estimated_resolution_date_text = ''
if barrier.estimated_resolution_date:
for date_format in date_formats:
formatted_date = barrier.estimated_resolution_date.strftime(date_format)
#if formatted_date == "23 June 2024":
# logger.critical("======================")
# logger.critical(barrier.id)
# logger.critical("======================")
estimated_resolution_date_text = estimated_resolution_date_text + f"Estimated to be resolved on {formatted_date}. "
#estimated_resolution_date_text = str(barrier.estimated_resolution_date)
date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
estimated_resolution_date_text = f"Estimated to be resolved on {date}."

corpus_object["id"] = barrier.id
corpus_object["barrier_corpus"] = (
@@ -415,21 +299,19 @@ def get_data() -> List[Dict]:
f"{barrier.export_description}."
)

logger.critical("-")
logger.critical(corpus_object["barrier_corpus"])
logger.critical("-")

data_dictionary.append(corpus_object)

return data_dictionary


def init():
global manager

if manager:
raise Exception("Related Barrier Manager already set")

manager = RelatedBarrierManager()

if not manager.get_barrier_ids():
logger.info("(Related Barriers): Initialising)")
data = get_data()
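For reference, the search path that survives this commit boils down to: embed the barrier corpus together with the search term, read off the cosine-similarity row for the search term, drop anything below the threshold, then sort and slice. A minimal self-contained sketch of that flow follows — the model name mirrors get_transformer(), while the corpus text, threshold and quantity are illustrative values, not the project's real data or constants.

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("paraphrase-MiniLM-L3-v2")

# Illustrative corpus keyed by barrier id; the real corpus comes from get_data().
barrier_corpus = {
    "barrier-1": "Import licence delays . Steel exporters face slow approvals.",
    "barrier-2": "Port congestion . Perishable goods held at customs for weeks.",
}
search_term = "steel export licences"

# Embed the corpus plus the search term as the final entry.
texts = list(barrier_corpus.values()) + [search_term]
embeddings = model.encode(texts)

# Pairwise cosine similarity; the last row scores the search term against everything.
cosine_sim = util.cos_sim(embeddings, embeddings)
search_scores = cosine_sim[len(embeddings) - 1]

# Filter by threshold, sort ascending by score and keep the top `quantity`,
# mirroring the manager's filter/sort/slice.
similarity_threshold = 0.2   # illustrative, not the project's SIMILARITY_THRESHOLD
quantity = 5                 # illustrative, not SIMILAR_BARRIERS_LIMIT
scored = {
    barrier_id: float(score)
    for barrier_id, score in zip(barrier_corpus, search_scores)
    if score > similarity_threshold
}
top_matches = sorted(scored.items(), key=lambda x: x[1])[-quantity:]
print(top_matches)  # [(barrier_id, score), ...] lowest-to-highest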
4 changes: 4 additions & 0 deletions api/related_barriers/serializers.py
@@ -13,3 +13,7 @@ class BarrierRelatedListSerializer(serializers.Serializer):
status = StatusField(required=False)
location = serializers.CharField(read_only=True)
similarity = serializers.FloatField(read_only=True)


class SearchRequest(serializers.Serializer):
search_term = serializers.CharField(required=True)
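The new SearchRequest serializer only guards the query string; roughly, it behaves as sketched below (assumes a configured Django/DRF environment, e.g. run via ./manage.py shell).

from rest_framework import serializers

class SearchRequest(serializers.Serializer):
    search_term = serializers.CharField(required=True)

# Valid input: .data exposes the cleaned value the view passes to the manager.
ok = SearchRequest(data={"search_term": "steel tariffs"})
ok.is_valid(raise_exception=True)
print(ok.data["search_term"])  # "steel tariffs"

# Missing or blank input: raise_exception=True turns this into a 400 response
# when it happens inside the view.
bad = SearchRequest(data={})
print(bad.is_valid())  # False
print(bad.errors)      # {'search_term': [...]}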
2 changes: 1 addition & 1 deletion api/related_barriers/urls.py
@@ -13,6 +13,6 @@
"barriers/<uuid:pk>/related-barriers", related_barriers, name="related-barriers"
),
path(
"barriers/dashboard/<str:search_term>", related_barriers_search, name="related-barriers-search"
"related-barriers", related_barriers_search, name="related-barriers-search"
),
]
19 changes: 11 additions & 8 deletions api/related_barriers/views.py
@@ -11,7 +11,7 @@
SIMILARITY_THRESHOLD,
BarrierEntry,
)
from api.related_barriers.serializers import BarrierRelatedListSerializer
from api.related_barriers.serializers import BarrierRelatedListSerializer, SearchRequest

logger = logging.getLogger(__name__)

@@ -42,26 +42,29 @@ def related_barriers(request, pk) -> Response:
).data
)

#@api_view(["GET"])
def related_barriers_search(request, search_term, log_score) -> Response:

@api_view(["GET"])
def related_barriers_search(request) -> Response:
"""
Return a list of related barriers given a search term/phrase
"""

serializer = SearchRequest(data=request.GET)
serializer.is_valid(raise_exception=True)

if manager.manager is None:
manager.init()

similar_barrier_ids = manager.manager.get_similar_barriers_searched(
search_term=search_term,
barrier_scores = manager.manager.get_similar_barriers_searched(
search_term=serializer.data['search_term'],
similarity_threshold=SIMILARITY_THRESHOLD,
quantity=SIMILAR_BARRIERS_LIMIT,
log_score=log_score
)

return similar_barrier_ids
barrier_ids = [b[0] for b in barrier_scores]

return Response(
BarrierRelatedListSerializer(
Barrier.objects.filter(id__in=similar_barrier_ids), many=True
Barrier.objects.filter(id__in=barrier_ids), many=True
).data
)
