Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Related Barriers #906

Open
wants to merge 17 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions api/barriers/signals/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
send_top_priority_notification,
)
from api.metadata.constants import TOP_PRIORITY_BARRIER_STATUS
from api.related_barriers import manager
from api.related_barriers.constants import BarrierEntry
from api.related_barriers.manager import BARRIER_UPDATE_FIELDS
from api.related_barriers.tasks import update_related_barrier

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -228,15 +227,4 @@ def related_barrier_update_embeddings(sender, instance, *args, **kwargs):
)

if changed and not current_barrier_object.draft:
if not manager.manager:
manager.init()
try:
manager.manager.update_barrier(
BarrierEntry(
id=str(current_barrier_object.id),
barrier_corpus=manager.barrier_to_corpus(instance),
)
)
except Exception as e:
# We don't want barrier embedding updates to break worker so just log error
logger.critical(str(e))
update_related_barrier(barrier_id=str(instance.pk))
45 changes: 45 additions & 0 deletions api/related_barriers/management/commands/related_barriers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import time

from django.core.management import BaseCommand

from api.related_barriers import manager


class Command(BaseCommand):
help = "Attach policy teams to barriers"

def add_arguments(self, parser):
parser.add_argument("--flush", type=bool, help="Flush cache")
parser.add_argument("--reindex", type=bool, help="Reindex cache")
parser.add_argument(
"--stats", type=bool, help="Get Related Barrier cache stats"
)

def handle(self, *args, **options):
flush = options["flush"]
reindex = options["reindex"]
stats = options["stats"]

rb_manager = manager.RelatedBarrierManager()

barrier_count = len(rb_manager.get_barrier_ids())

if stats:
print(f"Barrier Count: {barrier_count}")

if reindex:
print(f"Reindexing {barrier_count} barriers")
data = manager.get_data()

rb_manager.flush()

s = time.time()
rb_manager.set_data(data)
end = time.time()
print("Training time: ", end - s)
return

if flush:
print("Flushing")
rb_manager.flush()
return
121 changes: 56 additions & 65 deletions api/related_barriers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Optional

import numpy
import torch
from django.core.cache import cache
from sentence_transformers import SentenceTransformer, util

Expand All @@ -29,8 +30,7 @@ def wrapper(*args, **kwargs):

@timing
def get_transformer():
# SentenceTransformer("all-MiniLM-L6-v2")
return SentenceTransformer("paraphrase-MiniLM-L3-v2")
return SentenceTransformer("all-MiniLM-L6-v2")


BARRIER_UPDATE_FIELDS: List[str] = ["title", "summary"]
Expand Down Expand Up @@ -164,7 +164,7 @@ def remove_barrier(self, barrier: BarrierEntry, barrier_ids=None) -> None:
@timing
def update_barrier(self, barrier: BarrierEntry) -> None:
logger.info(f"(Related Barriers): update_barrier {barrier.id}")
barrier_ids = manager.get_barrier_ids()
barrier_ids = self.get_barrier_ids()
if barrier.id in barrier_ids:
self.remove_barrier(barrier, barrier_ids)
self.add_barrier(barrier, barrier_ids)
Expand Down Expand Up @@ -200,7 +200,8 @@ def get_similar_barriers(
for k, v in barrier_scores.items()
if v > similarity_threshold and k != barrier.id
}
return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
return [(r[0], torch.round(torch.tensor(r[1]), decimals=4)) for r in results]

@timing
def get_similar_barriers_searched(
Expand All @@ -222,7 +223,7 @@ def get_similar_barriers_searched(
logger.info("(Related Barriers): get_similar_barriers_searched")

if not (barrier_ids := self.get_barrier_ids()):
self.set_data(get_data_2())
self.set_data(get_data())
barrier_ids = self.get_barrier_ids() or []

if not barrier_ids:
Expand All @@ -249,7 +250,8 @@ def get_similar_barriers_searched(
for k, v in barrier_scores.items()
if v > similarity_threshold and k != embedded_index
}
return sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
results = sorted(barrier_scores.items(), key=lambda x: x[1])[-quantity:]
return [(r[0], torch.round(r[1], decimals=4)) for r in results]


manager: Optional[RelatedBarrierManager] = None
Expand All @@ -258,68 +260,17 @@ def get_similar_barriers_searched(
def get_data() -> List[Dict]:
from api.barriers.models import Barrier

return [
{"id": barrier.id, "barrier_corpus": f"{barrier.title} . {barrier.summary}"}
for barrier in Barrier.objects.filter(archived=False).exclude(draft=True)
]


def get_data_2() -> List[Dict]:
from api.barriers.models import Barrier

data_dictionary = []
barriers = Barrier.objects.filter(archived=False).exclude(draft=True)
barriers = (
Barrier.objects.prefetch_related("interactions_documents")
.filter(archived=False)
.exclude(draft=True)
)
for barrier in barriers:
corpus_object = {}

companies_affected_list = ""
if barrier.companies:
companies_affected_list = ", ".join(
[company["name"] for company in barrier.companies]
)

other_organisations_affected_list = ""
if barrier.related_organisations:
other_organisations_affected_list = ", ".join(
[company["name"] for company in barrier.related_organisations]
)

notes_text_list = ", ".join(
[note.text for note in barrier.interactions_documents.all()]
)

sectors_list = [
get_sector(str(sector_id))["name"] for sector_id in barrier.sectors
]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ", ".join(sectors_list)

overseas_region_text = get_barriers_overseas_region(
barrier.country, barrier.trading_bloc
data_dictionary.append(
{"id": str(barrier.id), "barrier_corpus": barrier_to_corpus(barrier)}
)

estimated_resolution_date_text = ""
if barrier.estimated_resolution_date:
date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
estimated_resolution_date_text = f"Estimated to be resolved on {date}."

corpus_object["id"] = barrier.id
corpus_object["barrier_corpus"] = (
f"{barrier.title}. "
f"{barrier.summary}. "
f"{sectors_text} "
f"{barrier.country_name}. "
f"{overseas_region_text}. "
f"{companies_affected_list} "
f"{other_organisations_affected_list} "
f"{notes_text_list}. "
f"{barrier.status_summary}. "
f"{estimated_resolution_date_text}. "
f"{barrier.export_description}."
)

data_dictionary.append(corpus_object)

return data_dictionary


Expand All @@ -338,4 +289,44 @@ def init():


def barrier_to_corpus(barrier) -> str:
return barrier.title + ". " + barrier.summary
companies_affected_list = ""
if barrier.companies:
companies_affected_list = ", ".join(
[company["name"] for company in barrier.companies]
)

other_organisations_affected_list = ""
if barrier.related_organisations:
other_organisations_affected_list = ", ".join(
[company["name"] for company in barrier.related_organisations]
)

notes_text_list = ", ".join(
[note.text for note in barrier.interactions_documents.all()]
)

sectors_list = [
get_sector(str(sector_id))["name"]
for sector_id in barrier.sectors
if get_sector(str(sector_id))
]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ", ".join(sectors_list)

overseas_region_text = get_barriers_overseas_region(
barrier.country, barrier.trading_bloc
)

estimated_resolution_date_text = ""
if barrier.estimated_resolution_date:
date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
estimated_resolution_date_text = f"Estimated to be resolved on {date}."

return (
f"{barrier.title}. {barrier.summary}. "
f"{sectors_text} {barrier.country_name}. "
f"{overseas_region_text}. {companies_affected_list} "
f"{other_organisations_affected_list} {notes_text_list}. "
f"{barrier.status_summary}. {estimated_resolution_date_text}. "
f"{barrier.export_description}."
)
34 changes: 34 additions & 0 deletions api/related_barriers/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging

from celery import shared_task
from django.core.management import call_command

from api.barriers.models import Barrier
from api.related_barriers import manager
from api.related_barriers.constants import BarrierEntry

logger = logging.getLogger(__name__)


def update_related_barrier(barrier_id: str):
logger.info(f"Updating related barrier embeddings for: {barrier_id}")
if not manager.manager:
manager.init()
try:
manager.manager.update_barrier(
BarrierEntry(
id=barrier_id,
barrier_corpus=manager.barrier_to_corpus(
Barrier.objects.get(pk=barrier_id)
),
)
)
except Exception as e:
# We don't want barrier embedding updates to break worker so just log error
logger.critical(str(e))


@shared_task
def reindex_related_barriers():
"""Schedule daily task to reindex"""
call_command("related_barriers", "--reindex", "1")
6 changes: 6 additions & 0 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,12 @@
CELERY_BEAT_SCHEDULE = {}

if not DEBUG:
# Runs daily at midnight
CELERY_BEAT_SCHEDULE["reindex_related_barriers"] = {
"task": "api.related_barriers.tasks.reindex_related_barriers",
"schedule": crontab(minute=0, hour=0),
}

# Runs daily at 6am
CELERY_BEAT_SCHEDULE["send_notification_emails"] = {
"task": "api.user.tasks.send_notification_emails",
Expand Down
6 changes: 4 additions & 2 deletions tests/barriers/signals/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
def test_related_barrier_handler():
barrier = BarrierFactory()

with mock.patch("api.barriers.signals.handlers.manager") as mock_manager:
with mock.patch(
"api.barriers.signals.handlers.update_related_barrier"
) as mock_update_related_barrier:
barrier.title = "New Title"
barrier.save()

assert mock_manager.manager.update_barrier.call_count == 1
mock_update_related_barrier.assert_called_once_with(barrier_id=str(barrier.pk))
Loading