firebreak: Remove ML dependencies and use related barriers api #838

Open: wants to merge 6 commits into master
2 changes: 1 addition & 1 deletion Makefile
@@ -80,7 +80,7 @@ __timestamp = $(shell date +%F_%H-%M)
pip-install: ## Install pip requirements inside the container.
@echo "$$(tput setaf 3)🙈 Installing Pip Packages 🙈$$(tput sgr 0)"
@docker-compose exec web poetry lock
@docker-compose exec web poetry export --without-hashes -f requirements.txt -o requirements-app.txt
@docker-compose exec web poetry export --without-hashes -f requirements.txt -o requirements.txt
@docker-compose exec web poetry export --with dev --without-hashes -f requirements.txt -o requirements-dev.txt
@docker-compose exec web pip install -r requirements-dev.txt
@docker-compose exec web sed -i '1i# ======\n# DO NOT EDIT - use pyproject.toml instead!\n# Generated: $(__timestamp)\n# ======' requirements.txt
18 changes: 7 additions & 11 deletions api/barriers/signals/handlers.py
@@ -21,9 +21,7 @@
send_top_priority_notification,
)
from api.metadata.constants import TOP_PRIORITY_BARRIER_STATUS
from api.related_barriers import manager
from api.related_barriers.constants import BarrierEntry
from api.related_barriers.manager import BARRIER_UPDATE_FIELDS
from api.related_barriers import client

logger = logging.getLogger(__name__)

@@ -221,21 +219,19 @@ def related_barrier_update_embeddings(sender, instance, *args, **kwargs):

changed = any(
getattr(current_barrier_object, field) != getattr(instance, field)
for field in BARRIER_UPDATE_FIELDS
for field in ["title", "summary"]
)
logger.info(
f"(Handler) Updating related barrier embeddings for {instance.pk}: {changed}"
)

if changed and not current_barrier_object.draft:
if not manager.manager:
manager.init()
try:
manager.manager.update_barrier(
BarrierEntry(
id=str(current_barrier_object.id),
barrier_corpus=manager.barrier_to_corpus(instance),
)
# Fail gracefully
client.get_related_barriers(
pk=str(current_barrier_object.id),
title=instance.title,
summary=instance.summary,
)
except Exception as e:
# We don't want barrier embedding updates to break worker so just log error
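Reviewer note: the handler now delegates embedding updates to an HTTP call and swallows exceptions, so a service outage cannot break the worker. A minimal test sketch, not part of this PR, of how the new client call could be isolated; the mocked base URL and test name are illustrative, and the payload shape mirrors api/related_barriers/client.py:

# Sketch (not in this PR): isolating the outgoing HTTP call made by the new client.
import os
from unittest import mock

from api.related_barriers import client


@mock.patch.dict(os.environ, {"RELATED_BARRIERS_BASE_URL": "http://related-barriers.test"})
@mock.patch("api.related_barriers.client.requests.post")
def test_get_related_barriers_posts_title_and_summary(mock_post):
    # Stub the service reply; only the outgoing payload is asserted.
    mock_post.return_value.json.return_value = {"results": []}

    results = client.get_related_barriers(
        pk="1234", title="Example title", summary="Example summary"
    )

    assert results == []
    _, kwargs = mock_post.call_args
    # The corpus sent to the service is "<title>. <summary>"
    assert kwargs["json"]["corpus"] == "Example title. Example summary"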
118 changes: 118 additions & 0 deletions api/related_barriers/client.py
@@ -0,0 +1,118 @@
import os
from typing import List, Dict

import requests

from api.barriers.tasks import get_barriers_overseas_region
from api.metadata.utils import get_sector


def get_data() -> List[Dict]:
from api.barriers.models import Barrier

data_dictionary = []
barriers = (
Barrier.objects.prefetch_related("interactions_documents")
.filter(archived=False)
.exclude(draft=True)
)
for barrier in barriers:
data_dictionary.append(
{"id": str(barrier.id), "barrier_corpus": barrier_to_corpus(barrier)}
)

return data_dictionary
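For context, get_data() produces one entry per live (non-archived, non-draft) barrier. The shape handed to the seeding step looks roughly like this; values are fabricated for illustration:

# Illustrative only: the structure returned by get_data(), with made-up values.
example_entry = {
    "id": "a1b2c3d4-0000-0000-0000-000000000000",  # str(barrier.id)
    "barrier_corpus": "Barrier title. Barrier summary. Sectors, country, notes ...",
}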


def barrier_to_corpus(barrier) -> str:
companies_affected_list = ""
if barrier.companies:
companies_affected_list = ", ".join(
[company["name"] for company in barrier.companies]
)

other_organisations_affected_list = ""
if barrier.related_organisations:
other_organisations_affected_list = ", ".join(
[company["name"] for company in barrier.related_organisations]
)

notes_text_list = ", ".join(
[note.text for note in barrier.interactions_documents.all()]
)

sectors_list = [
get_sector(str(sector_id))["name"]
for sector_id in barrier.sectors
if get_sector(str(sector_id))
]
sectors_list.append(get_sector(barrier.main_sector)["name"])
sectors_text = ", ".join(sectors_list)

overseas_region_text = get_barriers_overseas_region(
barrier.country, barrier.trading_bloc
)

estimated_resolution_date_text = ""
if barrier.estimated_resolution_date:
date = barrier.estimated_resolution_date.strftime("%d-%m-%Y")
estimated_resolution_date_text = f"Estimated to be resolved on {date}."

return (
f"{barrier.title}. {barrier.summary}. "
f"{sectors_text} {barrier.country_name}. "
f"{overseas_region_text}. {companies_affected_list} "
f"{other_organisations_affected_list} {notes_text_list}. "
f"{barrier.status_summary}. {estimated_resolution_date_text}. "
f"{barrier.export_description}."
)
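Spelled out, barrier_to_corpus() concatenates the fields in this fixed order (placeholders only, not real data). Empty fields leave stray separators in the string, which the downstream model presumably tolerates:

# Field order produced by barrier_to_corpus():
# "<title>. <summary>. <sectors> <country>. <overseas region>. <companies> "
# "<other organisations> <notes>. <status summary>. "
# "<estimated resolution date>. <export description>."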


def seed_data():
from django.db.models import CharField
from django.db.models import Value as V
from django.db.models.functions import Concat

from api.barriers.models import Barrier

return (
Barrier.objects.filter(archived=False)
.exclude(draft=True)
.annotate(
barrier_corpus=Concat("title", V(". "), "summary", output_field=CharField())
)
.values("id", "barrier_corpus")
)
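Note that seed_data() builds a much smaller corpus than get_data() (title and summary only, concatenated in the database) and emits "id"/"barrier_corpus" keys, whereas the /seed endpoint below is posted "barrier_id"/"corpus". A sketch of reshaping its rows, should it be wired up (not part of this PR):

# Sketch: adapting seed_data() rows to the payload keys that seed() posts.
payload = [
    {"barrier_id": str(row["id"]), "corpus": row["barrier_corpus"]}
    for row in seed_data()
]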


def seed():
data = get_data()
import spacy
nlp = spacy.load("en_core_web_sm")

print("Num barriers: ", len(data))
print("Example barrier: ", data[0])
lines = []
for d in data:
doc = nlp(d["barrier_corpus"])
filtered_tokens = [token.text for token in doc if not token.is_stop]
lines.append({"barrier_id": str(d["id"]), "corpus": " ".join(filtered_tokens)})

print(lines[0])

requests.post(
f"{os.environ['RELATED_BARRIERS_BASE_URL']}/seed",
json={
"data": lines
},
)
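The spaCy step above strips stop words from each corpus before posting. A standalone illustration; en_core_web_sm must be installed first, e.g. via python -m spacy download en_core_web_sm:

# Standalone illustration of the stop-word filtering used in seed().
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp("The licences are taking six months to issue in Brazil")
print(" ".join(token.text for token in doc if not token.is_stop))
# Prints something like: licences taking months issue Brazil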


def get_related_barriers(pk, title, summary):
res = requests.post(
f"{os.environ['RELATED_BARRIERS_BASE_URL']}/related-barriers",
json={"barrier_id": f"{pk}", "corpus": f"{title}. {summary}"},
)
return res.json()["results"]
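get_related_barriers() assumes a JSON body with a "results" key and will raise a KeyError or a requests exception on anything else, which the signal handler above catches. A more defensive variant might look like this; the timeout value and error handling are assumptions, not part of this PR:

# Sketch of a defensive variant; endpoint and payload mirror get_related_barriers().
import os

import requests


def get_related_barriers_safe(pk, title, summary, timeout=5):
    res = requests.post(
        f"{os.environ['RELATED_BARRIERS_BASE_URL']}/related-barriers",
        json={"barrier_id": str(pk), "corpus": f"{title}. {summary}"},
        timeout=timeout,  # avoid hanging the worker on a slow service
    )
    res.raise_for_status()
    return res.json().get("results", [])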
9 changes: 0 additions & 9 deletions api/related_barriers/constants.py

This file was deleted.

Empty file.
87 changes: 87 additions & 0 deletions api/related_barriers/management/commands/related_barriers.py
@@ -0,0 +1,87 @@
import os
import time

import requests
from django.core.management import BaseCommand
from api.related_barriers import client
from api.barriers.models import Barrier


class Command(BaseCommand):
help = "Seed or test the related barriers service"

def add_arguments(self, parser):
parser.add_argument("--seed", type=bool, help="Seed related barriers service")
parser.add_argument("--test", type=bool, help="Seed related barriers service")

def handle(self, *args, **options):
seed = options["seed"]
test = options["test"]

if seed:
print('Seeding related barriers service ...')
start = time.time()
client.seed()
end = time.time()
print(f'Seeded in {end - start}s')

if test:
qs = Barrier.objects.order_by("-created_on")

for b in qs[10:43]:
print(b.title)

import csv

entries = []
cleaned_entries = []

import spacy
nlp = spacy.load("en_core_web_sm")

with open('/usr/src/app/api/related_barriers/management/commands/data.csv') as csvfile:
spamreader = csv.reader(csvfile, delimiter='#')
for barrier, row in zip(qs[10:43], spamreader):
print(barrier.id, len(row))
raw_text = f"{row[0]} . {row[1]}"
entries.append(
{
"barrier_id": str(barrier.id),
"corpus": raw_text
}
)
cleaned_entries.append(
{
"barrier_id": str(barrier.id),
"corpus": " ".join([t.text for t in nlp(raw_text) if not t.is_stop])
}
)


requests.post(
f"{os.environ['RELATED_BARRIERS_BASE_URL']}/seed",
json={
"data": entries
},
)
res = requests.get(
f"{os.environ['RELATED_BARRIERS_BASE_URL']}/cosine-similarity",
)

print(res.json())

requests.post(
f"{os.environ['RELATED_BARRIERS_BASE_URL']}/seed",
json={
"data": cleaned_entries
},
)
res2 = requests.get(
f"{os.environ['RELATED_BARRIERS_BASE_URL']}/cosine-similarity",
)

print(res2.json())

print(entries[0])
print(cleaned_entries[0])
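For reference, the command can be run as python manage.py related_barriers --seed 1, or driven from code. One caveat: argparse's type=bool converts any non-empty string to True, so --seed false would still seed. A sketch of invoking it programmatically instead:

# Sketch: running the command from Django, sidestepping the type=bool quirk.
from django.core.management import call_command

call_command("related_barriers", seed=True)  # reseed the related barriers service
call_command("related_barriers", test=True)  # run the CSV comparison test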