Skip to content

Commit

Permalink
Adding free-text-search to related barriers (#845)
Browse files Browse the repository at this point in the history
* Adding free-text-search to related barriers

* Rebase

* Update and working search endpoint

* format

* format

* Excluding archived barriers from inactive notification alerts (#858)

Co-authored-by: santinomolinaro <santino.molinaro@mail>

* Deleting cached barrier when creating and updating next steps items so they are reflected instantly on frontend (#861)

Co-authored-by: santinomolinaro <santino.molinaro@mail>

* TSS-1995: Move set_to_allowed_on into public barrier serializer (#862)

* tss-1995: set_to_allowed expected on public_barrier serializer

---------

Co-authored-by: abarolo <[email protected]>

* updates

* format

* format

* using the full text search on the search page

* lint fix

---------

Co-authored-by: santinomolinaro <santino.molinaro@mail>
Co-authored-by: abarolo <[email protected]>
Co-authored-by: abarolo <[email protected]>
Co-authored-by: Uka Osim <[email protected]>
Co-authored-by: Uka Osim <[email protected]>
  • Loading branch information
6 people authored Aug 27, 2024
1 parent e7f2972 commit e8b9f2a
Show file tree
Hide file tree
Showing 11 changed files with 526 additions and 186 deletions.
42 changes: 39 additions & 3 deletions api/barriers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from django.core.cache import cache
from django.core.validators import int_list_validator
from django.db import models
from django.db.models import CASCADE, Q, QuerySet
from django.db.models import CASCADE, CharField, Q, QuerySet
from django.db.models.functions import Cast
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django_filters.widgets import BooleanWidget
Expand Down Expand Up @@ -1279,8 +1280,8 @@ class BarrierFilterSet(django_filters.FilterSet):
combined_priority = django_filters.BaseInFilter(method="combined_priority_filter")
location = django_filters.BaseInFilter(method="location_filter")
admin_areas = django_filters.BaseInFilter(method="admin_areas_filter")
search = django_filters.Filter(method="text_search")
text = django_filters.Filter(method="text_search")
search = django_filters.Filter(method="vector_search")
text = django_filters.Filter(method="vector_search")

user = django_filters.Filter(method="my_barriers")
has_action_plan = django_filters.Filter(method="has_action_plan_filter")
Expand Down Expand Up @@ -1592,6 +1593,41 @@ def text_search(self, queryset, name, value):
| Q(combined_company_related_organisation_query)
)

def vector_search(self, queryset, name, value):
    """
    Restrict the queryset to barriers that match a free-text search term.

    The term is handed to the related-barriers manager, which returns
    (barrier id, score) pairs; only barriers whose id appears in that
    result are kept.

    Args:
        queryset: the queryset to filter
        name: the name of the filter (not used by this method)
        value: the free-text search term
    Returns:
        The filtered queryset, annotated with ``barrier_id`` holding the
        primary key cast to a character field.
    """
    # Imported inside the method, exactly as the original does.
    from api.related_barriers import manager as handler_manager
    from api.related_barriers.constants import (
        SIMILAR_BARRIERS_LIMIT,
        SIMILARITY_THRESHOLD,
    )

    # Build the module-level manager on first use.
    if handler_manager.manager is None:
        handler_manager.init()

    scored_matches = handler_manager.manager.get_similar_barriers_searched(
        search_term=value,
        similarity_threshold=SIMILARITY_THRESHOLD,
        quantity=SIMILAR_BARRIERS_LIMIT,
    )
    matching_ids = [match[0] for match in scored_matches]

    id_as_text = Cast("id", output_field=CharField())
    return queryset.filter(id__in=matching_ids).annotate(barrier_id=id_as_text)

def my_barriers(self, queryset, name, value):
if value:
current_user = self.get_user()
Expand Down
2 changes: 1 addition & 1 deletion api/barriers/signals/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def related_barrier_update_embeddings(sender, instance, *args, **kwargs):
manager.manager.update_barrier(
BarrierEntry(
id=str(current_barrier_object.id),
barrier_corpus=manager.barrier_to_corpus(current_barrier_object),
barrier_corpus=manager.barrier_to_corpus(instance),
)
)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging

from django.core.management import BaseCommand

from api.barriers.models import Barrier
from api.barriers.tasks import get_barriers_overseas_region
from api.metadata.utils import get_sector
from api.related_barriers.views import related_barriers_search

logger = logging.getLogger(__name__)


class Command(BaseCommand):
    """Run a related-barriers search for a term and log details of every match."""

    # ./manage.py run_search_related_test --searchterm 'Company called BARCLAYS PLC is affected by this barrier'

    help = "Test the related barriers engine with a search term"

    def add_arguments(self, parser):
        """Register the ``--searchterm`` command-line option."""
        parser.add_argument("--searchterm", type=str, help="Search term to use")

    def handle(self, *args, **options):
        """Execute the search and log one detail section per returned barrier id."""
        banner = "+++++++++"

        logger.critical(banner)
        logger.critical("Start Test")
        search_term = options["searchterm"]
        logger.critical(search_term)
        logger.critical(banner)

        logger.critical("SCORES -----")
        response_data = related_barriers_search(
            "request_dummy", search_term, log_score=True
        )
        logger.critical("------------")

        for bar_id in response_data:
            # The "search_term" entry is not a barrier id, so it is skipped.
            if bar_id == "search_term":
                continue

            logger.critical("-")
            barrier_detail = Barrier.objects.get(pk=bar_id)
            logger.critical(f"ID = {bar_id!s}")
            logger.critical(f"TITLE = {barrier_detail.title!s}")
            logger.critical(f"SUMMARY = {barrier_detail.summary!s}")

            sectors_list = [
                get_sector(sector)["name"] for sector in barrier_detail.sectors
            ]
            main_sector_name = get_sector(barrier_detail.main_sector)["name"]
            logger.critical(f"SECTORS = {main_sector_name!s} {sectors_list!s}")

            overseas_region = get_barriers_overseas_region(
                barrier_detail.country, barrier_detail.trading_bloc
            )
            logger.critical(f"COUNTRY = {overseas_region!s}")

            logger.critical(f"COMPANIES = {barrier_detail.companies!s}")
            logger.critical(
                f"OTHER COMPANIES = {barrier_detail.related_organisations!s}"
            )
            logger.critical(f"STATUS SUMMARY = {barrier_detail.status_summary!s}")
            logger.critical(
                f"EST. RESOLUTION DATE = {barrier_detail.estimated_resolution_date!s}"
            )
            logger.critical(f"EXPORT DESC. = {barrier_detail.export_description!s}")

            notes_list = [
                note.text for note in barrier_detail.interactions_documents.all()
            ]
            logger.critical(f"NOTES = {notes_list!s}")
            logger.critical("-")

        logger.critical(banner)
        logger.critical("End Test")
        logger.critical(banner)
118 changes: 105 additions & 13 deletions api/related_barriers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

import numpy
from django.core.cache import cache
from django.db.models import CharField
from django.db.models import Value as V
from django.db.models.functions import Concat
from sentence_transformers import SentenceTransformer, util

from api.barriers.tasks import get_barriers_overseas_region
from api.metadata.utils import get_sector
from api.related_barriers.constants import BarrierEntry

logger = logging.getLogger(__name__)
Expand All @@ -30,6 +29,7 @@ def wrapper(*args, **kwargs):

@timing
def get_transformer():
    """Instantiate the sentence-transformer model used by this module."""
    # Alternative left commented out in the original:
    # SentenceTransformer("all-MiniLM-L6-v2")
    model_name = "paraphrase-MiniLM-L3-v2"
    return SentenceTransformer(model_name)


Expand Down Expand Up @@ -197,10 +197,46 @@ def get_similar_barriers(
for k, v in barrier_scores.items()
if v > similarity_threshold and k != barrier.id
}
barrier_scores = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]

barrier_ids = [b[0] for b in barrier_scores]
return barrier_ids
@timing
def get_similar_barriers_searched(
    self,
    search_term: str,
    similarity_threshold: float,
    quantity: int = None,
    log_score=None,
):
    """
    Score stored barriers by cosine similarity to a free-text search term.

    Args:
        search_term: text that is embedded and compared with the stored
            barrier embeddings.
        similarity_threshold: only barriers scoring strictly above this
            value are returned.
        quantity: maximum number of results. ``None`` (the default) returns
            every match; zero or a negative number returns no results.
        log_score: accepted for call compatibility; not used by this method.

    Returns:
        A list of ``(barrier_id, score)`` tuples sorted by ascending score,
        so the closest match is the last element. Empty when no barriers
        are available.
    """
    logger.info("(Related Barriers): get_similar_barriers_searched")
    barrier_ids = self.get_barrier_ids()

    if not barrier_ids:
        # Nothing loaded yet: reload the data and read the ids again.
        self.set_data(get_data())
        barrier_ids = self.get_barrier_ids()

    if not barrier_ids:
        # Still nothing to compare against, so there can be no matches.
        return []

    # .cpu() is a no-op on CPU tensors and makes .numpy() safe when the
    # model lives on another device.
    search_term_embedding = (
        self.model.encode(search_term, convert_to_tensor=True).cpu().numpy()
    )
    embeddings = self.get_embeddings()

    # Stack once so both sides share a dtype, then compare only the search
    # row against the barrier rows instead of building the full
    # (N + 1) x (N + 1) similarity matrix.
    stacked = numpy.vstack([embeddings, search_term_embedding])
    scores = util.cos_sim(stacked[-1:], stacked[:-1])[0]

    barrier_scores = {
        barrier_id: score
        for barrier_id, score in dict(zip(barrier_ids, scores)).items()
        if score > similarity_threshold
    }
    ranked = sorted(barrier_scores.items(), key=lambda item: item[1])

    if quantity is None:
        return ranked
    if quantity <= 0:
        # ranked[-0:] would return everything, which is never what a
        # caller asking for zero results wants.
        return []
    return ranked[-quantity:]


manager: Optional[RelatedBarrierManager] = None
Expand All @@ -209,14 +245,69 @@ def get_similar_barriers(
def get_data() -> List[Dict]:
from api.barriers.models import Barrier

return (
Barrier.objects.filter(archived=False)
.exclude(draft=True)
.annotate(
barrier_corpus=Concat("title", V(". "), "summary", output_field=CharField())
return [
{"id": barrier.id, "barrier_corpus": f"{barrier.title} . {barrier.summary}"}
for barrier in Barrier.objects.filter(archived=False).exclude(draft=True)
]


def get_data_2() -> List[Dict]:
"""
Build one text corpus entry per non-archived, non-draft barrier.

Each entry is a dict with the barrier "id" and a "barrier_corpus" string
assembled from the title, summary, sector names, country name, overseas
region, affected companies and organisations, interaction note texts,
status summary, estimated resolution date and export description.
"""
from api.barriers.models import Barrier

data_dictionary = []
barriers = Barrier.objects.filter(archived=False).exclude(draft=True)
for barrier in barriers:
corpus_object = {}

# Comma-separated company names; left empty when the barrier has none.
companies_affected_list = ""
if barrier.companies:
companies_affected_list = ", ".join(
[company["name"] for company in barrier.companies]
)

# Same treatment for the related organisations.
other_organisations_affected_list = ""
if barrier.related_organisations:
other_organisations_affected_list = ", ".join(
[company["name"] for company in barrier.related_organisations]
)

# Text of every related interaction document, comma-separated.
notes_text_list = ", ".join(
[note.text for note in barrier.interactions_documents.all()]
)

# Sector names for the barrier's sectors, with the main sector name appended last.
# NOTE(review): this indexes ["name"] on whatever get_sector returns - confirm
# it cannot return None for an unset main_sector.
sectors_list = [
get_sector(str(sector_id))["name"] for sector_id in barrier.sectors
]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ", ".join(sectors_list)

overseas_region_text = get_barriers_overseas_region(
barrier.country, barrier.trading_bloc
)

# Only mention a resolution date when one is set (formatted day-month-year).
estimated_resolution_date_text = ""
if barrier.estimated_resolution_date:
date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
estimated_resolution_date_text = f"Estimated to be resolved on {date}."

corpus_object["id"] = barrier.id
corpus_object["barrier_corpus"] = (
f"{barrier.title}. "
f"{barrier.summary}. "
f"{sectors_text} "
f"{barrier.country_name}. "
f"{overseas_region_text}. "
f"{companies_affected_list} "
f"{other_organisations_affected_list} "
f"{notes_text_list}. "
f"{barrier.status_summary}. "
f"{estimated_resolution_date_text}. "
f"{barrier.export_description}."
)
# NOTE(review): the next two lines look like removed-diff residue from the
# previous get_data implementation rather than part of this function - confirm.
.values("id", "barrier_corpus")
)

data_dictionary.append(corpus_object)

return data_dictionary


def init():
Expand All @@ -226,6 +317,7 @@ def init():
raise Exception("Related Barrier Manager already set")

manager = RelatedBarrierManager()

if not manager.get_barrier_ids():
logger.info("(Related Barriers): Initialising)")
data = get_data()
Expand Down
27 changes: 18 additions & 9 deletions api/related_barriers/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@
from api.barriers.fields import StatusField


# TODO : standard list serialiser may suffice and the following not required base on final designs
class BarrierRelatedListSerializer(serializers.Serializer):
summary = serializers.CharField(read_only=True)
title = serializers.CharField(read_only=True)
id = serializers.UUIDField(read_only=True)
reported_on = serializers.DateTimeField(read_only=True)
modified_on = serializers.DateTimeField(read_only=True)
status = StatusField(required=False)
location = serializers.CharField(read_only=True)
similarity = serializers.FloatField(read_only=True)
summary = serializers.CharField()
title = serializers.CharField()
barrier_id = serializers.CharField()
reported_on = serializers.DateTimeField()
modified_on = serializers.DateTimeField()
status = StatusField()
location = serializers.SerializerMethodField()
similarity = serializers.FloatField()

def get_location(self, obj):
try:
return obj.location or ""
except Exception:
return ""


class SearchRequest(serializers.Serializer):
# Request serializer with a single mandatory text field, search_term.
search_term = serializers.CharField(required=True)
3 changes: 2 additions & 1 deletion api/related_barriers/urls.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.urls import path
from rest_framework.routers import DefaultRouter

from api.related_barriers.views import related_barriers
from api.related_barriers.views import related_barriers, related_barriers_search

app_name = "related_barriers"

Expand All @@ -12,4 +12,5 @@
path(
"barriers/<uuid:pk>/related-barriers", related_barriers, name="related-barriers"
),
path("related-barriers", related_barriers_search, name="related-barriers-search"),
]
Loading

0 comments on commit e8b9f2a

Please sign in to comment.