feat(search): Adds task to send email with search results as attachment
ERosendo committed Jan 16, 2025
1 parent 93d31a2 commit d4b2f1a
Showing 3 changed files with 190 additions and 17 deletions.
108 changes: 108 additions & 0 deletions cl/lib/search_utils.py
@@ -1,3 +1,4 @@
import logging
import pickle
import re
from typing import Any, Dict, List, Optional, Tuple, TypedDict
@@ -6,6 +7,7 @@
from asgiref.sync import async_to_sync, sync_to_async
from django.conf import settings
from django.core.cache import cache
from django.core.exceptions import PermissionDenied
from django.core.paginator import EmptyPage, Page, PageNotAnInteger
from django.http import HttpRequest
from django.http.request import QueryDict
@@ -38,7 +40,9 @@
from cl.search.documents import (
AudioDocument,
DocketDocument,
ESRECAPDocument,
OpinionClusterDocument,
OpinionDocument,
ParentheticalGroupDocument,
PersonDocument,
)
@@ -59,6 +63,19 @@

HYPERSCAN_TOKENIZER = HyperscanTokenizer(cache_dir=".hyperscan")

logger = logging.getLogger(__name__)


def check_pagination_depth(page_number):
"""Check if the pagination is too deep (indicating a crawler)"""

if page_number > settings.MAX_SEARCH_PAGINATION_DEPTH:
logger.warning(
"Query depth of %s denied access (probably a crawler)",
page_number,
)
raise PermissionDenied
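
# Usage sketch (commentary, not part of this commit): a caller applies the
# guard before paginating and lets Django map PermissionDenied to an HTTP 403:
#
#     check_pagination_depth(int(request.GET.get("page", 1)))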


def make_get_string(
request: HttpRequest,
@@ -620,3 +637,94 @@ def do_es_search(
),
"missing_citations": missing_citations_str,
}


def get_headers_for_search_export(type: str) -> list[str]:
"""Creates a list of headers suitable for CSV export of search results.
:param type: The type of Elasticsearch search to be performed. Valid values
are defined in the `SEARCH_TYPES` enum.
:return: A list of strings representing the CSV headers.
"""
match type:
case SEARCH_TYPES.PEOPLE:
keys = PersonDocument.__dict__["_fields"].keys()
case SEARCH_TYPES.ORAL_ARGUMENT:
keys = AudioDocument.__dict__["_fields"].keys()
case SEARCH_TYPES.PARENTHETICAL:
keys = ParentheticalGroupDocument.__dict__["_fields"].keys()
case SEARCH_TYPES.RECAP:
keys = set(
[
*DocketDocument.__dict__["_fields"].keys(),
*ESRECAPDocument.__dict__["_fields"].keys(),
]
)
case SEARCH_TYPES.OPINION:
keys = set(
[
*OpinionClusterDocument.__dict__["_fields"].keys(),
*OpinionDocument.__dict__["_fields"].keys(),
]
)

return [
key
for key in keys
if key not in ("person_child", "docket_child", "cluster_child")
]
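
# Usage sketch (commentary, not part of this commit): for RECAP searches the
# helper unions docket- and document-level fields, minus the join keys:
#
#     headers = get_headers_for_search_export(SEARCH_TYPES.RECAP)
#     assert "docket_child" not in headers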


def fetch_es_results_for_csv(
    queryset: QueryDict, search_type: str
) -> list[dict[str, Any]]:
    """Retrieves matching results from Elasticsearch and returns them as a list.
    This function flattens nested results (like those returned by opinion and
recap searches) and limit the number of results in the list to
`settings.MAX_SEARCH_RESULTS_EXPORTED`.
:param queryset: The query parameters sent by the user.
:param search_type: The type of Elasticsearch search to be performed.
:return: A list of dictionaries, where each dictionary represents a single
search result.
"""
csv_rows: list[dict[str, Any]] = []
while len(csv_rows) <= settings.MAX_SEARCH_RESULTS_EXPORTED:
search = do_es_search(
queryset, rows=settings.MAX_SEARCH_RESULTS_EXPORTED
)
if search["error"]:
return csv_rows

results = search["results"]
match search_type:
case SEARCH_TYPES.OPINION | SEARCH_TYPES.RECAP:
flat_results = []
for result in results.object_list:
parent_dict = result.to_dict()
child_docs = parent_dict.pop("child_docs")
if child_docs:
flat_results.extend(
[
parent_dict | doc["_source"].to_dict()
for doc in child_docs
]
)
else:
flat_results.extend([parent_dict])
case _:
flat_results = [
result.to_dict() for result in results.object_list
]

csv_rows.extend(flat_results)

if not results.has_next():
if len(csv_rows) <= settings.MAX_SEARCH_RESULTS_EXPORTED:
return csv_rows
break

queryset["page"] = results.next_page_number()

return csv_rows[: settings.MAX_SEARCH_RESULTS_EXPORTED]
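
The flattening step above leans on Python's dict-merge operator (3.9+): each child document's _source is overlaid onto its parent's fields, producing one CSV row per child, with child values winning on key collisions. A minimal sketch of that behavior with plain dicts (field names here are illustrative, not taken from the index mappings):

# A parent hit plus two nested child documents, shaped like a RECAP result.
parent = {"docket_id": 1, "case_name": "Foo v. Bar"}
children = [
    {"document_number": 1, "description": "Complaint"},
    {"document_number": 2, "description": "Answer"},
]

# One flattened row per child; child fields override parent fields on collision.
rows = [parent | child for child in children]
assert rows[0] == {
    "docket_id": 1,
    "case_name": "Foo v. Bar",
    "document_number": 1,
    "description": "Complaint",
}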
83 changes: 82 additions & 1 deletion cl/search/tasks.py
@@ -1,5 +1,7 @@
import csv
import io
import logging
from datetime import date
from datetime import date, datetime
from importlib import import_module
from random import randint
from typing import Any, Generator
@@ -8,8 +10,12 @@
from celery.canvas import chain
from django.apps import apps
from django.conf import settings
from django.contrib.auth.models import User
from django.core.exceptions import ObjectDoesNotExist
from django.core.mail import EmailMessage
from django.db.models import Prefetch, QuerySet
from django.http import QueryDict
from django.template import loader
from elasticsearch.exceptions import (
ApiError,
ConflictError,
@@ -34,6 +40,10 @@
from cl.celery_init import app
from cl.lib.elasticsearch_utils import build_daterange_query
from cl.lib.search_index_utils import get_parties_from_case_name
from cl.lib.search_utils import (
fetch_es_results_for_csv,
get_headers_for_search_export,
)
from cl.people_db.models import Person, Position
from cl.search.documents import (
ES_CHILD_ID,
@@ -45,6 +55,7 @@
PersonDocument,
PositionDocument,
)
from cl.search.forms import SearchForm
from cl.search.models import (
SEARCH_TYPES,
Docket,
@@ -341,6 +352,76 @@ def document_fields_to_update(
return fields_to_update


@app.task(
autoretry_for=(ConnectionError, ConflictError, ConnectionTimeout),
max_retries=3,
ignore_result=True,
)
def email_search_results(user_id: int, query: str):
"""Sends an email to the user with their search results as a CSV attachment.
:param user_id: The ID of the user to send the email to.
:param query: The user's search query string.
"""
user = User.objects.get(pk=user_id)
# Parse the query string into a dictionary
qd = QueryDict(query.encode(), mutable=True)

# Create a search form instance and validate the query data
search_form = SearchForm(qd)
if not search_form.is_valid():
return

# Get the cleaned data from the validated form
cd = search_form.cleaned_data

# Fetch search results from Elasticsearch based on query and search type
search_results = fetch_es_results_for_csv(
queryset=qd, search_type=cd["type"]
)
if not search_results:
return

# Get the headers for the CSV file based on the search type
csv_headers = get_headers_for_search_export(cd["type"])

# Create the CSV content and store in a StringIO object
csv_content = None
with io.StringIO() as output:
csvwriter = csv.DictWriter(
output,
fieldnames=csv_headers,
extrasaction="ignore",
quotechar='"',
quoting=csv.QUOTE_ALL,
)
csvwriter.writeheader()
for row in search_results:
csvwriter.writerow(row)

csv_content: str = output.getvalue()

# Prepare email content
txt_template = loader.get_template("search_results_email.txt")
email_context = {"username": user.username}

# Create email object
message = EmailMessage(
subject="Your Search Results are Ready!",
body=txt_template.render(email_context),
from_email=settings.DEFAULT_FROM_EMAIL,
to=[user.email],
)

# Generate a filename for the CSV attachment with timestamp
now = datetime.now()
filename = f'search_results_{now.strftime("%Y%m%d_%H%M%S")}.csv'

# Send email with attachments
message.attach(filename, csv_content, "text/csv")
message.send(fail_silently=False)
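
# Commentary (not part of this commit): the template receives only
# {"username": ...}, so a minimal compatible search_results_email.txt
# might read:
#
#     Hello {{ username }},
#
#     Your search results are attached as a CSV file.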


@app.task(
bind=True,
autoretry_for=(ConnectionError, ConflictError, ConnectionTimeout),
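
How the task gets enqueued isn't shown in this commit. A plausible call site (a sketch; the view name and response handling are assumptions, not code from the repository) would serialize the current query string and hand it to Celery:

from django.http import HttpResponse

from cl.search.tasks import email_search_results

def export_search_results(request):
    # Serialize the live query; the task re-parses it with QueryDict,
    # validates it through SearchForm, and re-runs the search before emailing.
    email_search_results.delay(request.user.pk, request.GET.urlencode())
    return HttpResponse(status=202)  # accepted; the CSV arrives by email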
16 changes: 0 additions & 16 deletions cl/search/views.py
@@ -1,13 +1,10 @@
import logging
from datetime import date, datetime, timedelta, timezone
from urllib.parse import quote

from asgiref.sync import async_to_sync
from cache_memoize import cache_memoize
from django.conf import settings
from django.contrib import messages
from django.contrib.auth.models import User
from django.core.exceptions import PermissionDenied
from django.db.models import Count, Sum
from django.http import HttpRequest, HttpResponse
from django.shortcuts import HttpResponseRedirect, get_object_or_404, render
@@ -37,19 +34,6 @@
from cl.stats.utils import tally_stat
from cl.visualizations.models import SCOTUSMap

logger = logging.getLogger(__name__)


def check_pagination_depth(page_number):
"""Check if the pagination is too deep (indicating a crawler)"""

if page_number > settings.MAX_SEARCH_PAGINATION_DEPTH:
logger.warning(
"Query depth of %s denied access (probably a crawler)",
page_number,
)
raise PermissionDenied


@cache_memoize(5 * 60)
def get_homepage_stats():
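
The callers of the deleted block aren't visible here; presumably the collapsed parts of views.py switch to the helper's new home in cl/lib/search_utils.py. A one-line sketch of that assumed import change:

# Assumed follow-up in cl/search/views.py (not visible in this diff):
from cl.lib.search_utils import check_pagination_depth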
