diff --git a/app/public/cantusdata/test/core/views/test_search_notation_view.py b/app/public/cantusdata/test/core/views/test_search_notation_view.py
new file mode 100644
index 00000000..d583b755
--- /dev/null
+++ b/app/public/cantusdata/test/core/views/test_search_notation_view.py
@@ -0,0 +1,131 @@
+from rest_framework.test import APITransactionTestCase
+from django.core.management import call_command
+from django.urls import reverse
+
+from cantusdata.views.search_notation import SearchNotationView, NotationSearchException
+
+TEST_MEI_FILES_PATH = "cantusdata/test/core/helpers/mei_processing/test_mei_files"
+
+
+class TestSearchNotationView(APITransactionTestCase):
+    search_notation_view = SearchNotationView()
+
+    def setUp(self) -> None:
+        call_command(
+            "index_manuscript_mei",
+            "123723",
+            "--min-ngram",
+            "1",
+            "--max-ngram",
+            "5",
+            "--mei-dir",
+            TEST_MEI_FILES_PATH,
+        )
+
+    def test_create_query_string(self) -> None:
+        with self.subTest("Test invalid query"):
+            with self.assertRaises(NotationSearchException):
+                self.search_notation_view.create_query_string(
+                    "a_b_q", q_type="pitch_names"
+                )
+        with self.subTest("Test valid query"):
+            query = "u d U r "
+            query_type = "contour"
+            query_string = self.search_notation_view.create_query_string(
+                query, query_type
+            )
+            expected_query_string = "contour:u_d_u_r"
+            self.assertEqual(query_string, expected_query_string)
+        # We add a separate subtest for a "pitch_names_invariant" query since it
+        # has a slightly different logic (we get transpositions and chain them
+        # together with ORs).
+        with self.subTest("Test pitch_names_invariant query"):
+            query = "c d e"
+            query_type = "pitch_names_invariant"
+            query_string = self.search_notation_view.create_query_string(
+                query, query_type
+            )
+            expected_query_string = "pitch_names:(c_d_e OR d_e_f OR e_f_g OR f_g_a OR g_a_b OR a_b_c OR b_c_d)"
+            self.assertEqual(query_string, expected_query_string)
+
+    def test_do_query(self) -> None:
+        with self.subTest("Test fields returned"):
+            # Test that, in general, the fields returned are as expected
+            expected_results_fields = [
+                "boxes",
+                "contour",
+                "semitones",
+                "pnames",
+            ]
+            results, _ = self.search_notation_view.do_query(
+                123723, "contour:u d u", 100, 0
+            )
+            results_fields = list(results[0].keys())
+            self.assertTrue(set(expected_results_fields).issubset(results_fields))
+            # Test a case where we know that neume names are returned and ensure
+            # that the "neumes" field is present in the results
+            results_neume_names, _ = self.search_notation_view.do_query(
+                123723, "neume_names:punctum", 100, 0
+            )
+            results_neume_names_fields = list(results_neume_names[0].keys())
+            expected_results_fields.append("neumes")
+            self.assertTrue(
+                set(expected_results_fields).issubset(results_neume_names_fields)
+            )
+        with self.subTest("Test rows and start parameters"):
+            results_rows_100_start_0, _ = self.search_notation_view.do_query(
+                123723, "neume_names:punctum", 100, 0
+            )
+            self.assertEqual(len(results_rows_100_start_0), 100)
+            results_rows_10_start_0, _ = self.search_notation_view.do_query(
+                123723, "neume_names:punctum", 10, 0
+            )
+            self.assertEqual(len(results_rows_10_start_0), 10)
+            self.assertEqual(results_rows_100_start_0[:10], results_rows_10_start_0)
+            results_rows_10_start_10, _ = self.search_notation_view.do_query(
+                123723, "neume_names:punctum", 10, 10
+            )
+            self.assertEqual(len(results_rows_10_start_10), 10)
+            self.assertEqual(results_rows_100_start_0[10:20], results_rows_10_start_10)
+        with self.subTest("Test manuscript_id parameter"):
+            _, num_found_123723 = self.search_notation_view.do_query(
+                123723, "neume_names:punctum", 100, 0
+            )
+            _, num_found_123724 = self.search_notation_view.do_query(
+                123724, "neume_names:punctum", 100, 0
+            )
+            self.assertGreater(num_found_123723, 0)
+            self.assertEqual(num_found_123724, 0)
+
+    def test_get(self) -> None:
+        url = reverse("search-notation-view")
+        with self.subTest("Test missing required parameters"):
+            params_no_manuscript: dict[str, str | int] = {
+                "q": "u d u",
+                "type": "contour",
+            }
+            response_no_manuscript = self.client.get(url, params_no_manuscript)
+            self.assertEqual(response_no_manuscript.status_code, 400)
+            params_no_type: dict[str, str | int] = {"q": "u d u", "manuscript": 123723}
+            response_no_type = self.client.get(url, params_no_type)
+            self.assertEqual(response_no_type.status_code, 400)
+            params_no_q: dict[str, str | int] = {
+                "type": "contour",
+                "manuscript": 123723,
+            }
+            response_no_q = self.client.get(url, params_no_q)
+            self.assertEqual(response_no_q.status_code, 400)
+        with self.subTest("Test response"):
+            params: dict[str, str | int] = {
+                "q": "u d u",
+                "type": "contour",
+                "manuscript": 123723,
+            }
+            response = self.client.get(url, params)
+            self.assertEqual(response.status_code, 200)
+            response_data = response.json()
+            self.assertIn("results", response_data)
+            self.assertIn("numFound", response_data)
+
+    def tearDown(self) -> None:
+        call_command("index_manuscript_mei", "123723", "--flush-index")
diff --git a/app/public/cantusdata/views/search_notation.py b/app/public/cantusdata/views/search_notation.py
index 4f1fd7d7..bac85719 100644
--- a/app/public/cantusdata/views/search_notation.py
+++ b/app/public/cantusdata/views/search_notation.py
@@ -1,19 +1,48 @@
-from django.conf import settings
+from typing import Any, TypedDict, Union, Optional, NotRequired
+import requests
+from django.conf import settings
 
 from rest_framework.views import APIView
 from rest_framework.response import Response
+from rest_framework.request import Request
 from rest_framework.exceptions import APIException
 
-from cantusdata.helpers import search_utils
-import solr
-import json
-import types
-from operator import itemgetter
+from cantusdata.helpers.search_utils import validate_query, get_transpositions
+
+RETURNED_FIELDS = [
+    "manuscript_id",
+    "folio",
+    "image_uri",
+    "pitch_names",
+    "contour",
+    "semitone_intervals",
+    "neume_names",
+    "location_json",
+]
+
+
+class SolrQueryResultItem(TypedDict):
+    manuscript_id: int
+    folio: str
+    image_uri: str
+    pitch_names: str
+    contour: str
+    semitone_intervals: str
+    neume_names: str
+    location_json: list[dict[str, int]]
+
+
+class NotationSearchResultItem(TypedDict):
+    boxes: list[dict[str, Union[int, str]]]
+    contour: list[str]
+    semitones: list[str]
+    pnames: list[str]
+    neumes: NotRequired[list[str]]
 
 
-class NotationException(APIException):
+class NotationSearchException(APIException):
     status_code = 400
-    default_detail = "Notation search request invalid"
+    default_detail = "Notation search request invalid."
 
 
 class SearchNotationView(APIView):
@@ -21,156 +50,141 @@ class SearchNotationView(APIView):
     Search algorithm adapted from the Liber Usualis code
     """
 
-    def get(self, request, *args, **kwargs):
+    def get(self, request: Request, *args: Any, **kwargs: Any) -> Response:
+        """
+        Handle a notation search request.
+
+        Required query parameters: "q" (the query), "type" (the query
+        type) and "manuscript" (the numeric manuscript ID). Optional
+        query parameters: "rows" (default 100) and "start" (default 0).
+
+        Raises a NotationSearchException (HTTP 400) if a required
+        parameter is missing, if a numeric parameter is not a
+        non-negative integer, or if the query is invalid.
+        """
         q = request.GET.get("q", None)
         stype = request.GET.get("type", None)
-        manuscript = request.GET.get("manuscript", None)
-        rows = request.GET.get("rows", "100")
-        start = request.GET.get("start", "0")
-
-        # Give a 400 if there's a notation exception, and let
-        # anything else give a 500
-        results = self.do_query(manuscript, stype, q)
-
-        return Response({"numFound": len(results), "results": results})
-
-    def do_query(self, manuscript, qtype, query):
-        # This will be appended to the search query so that we only get
-        # data from the manuscript that we want!
-        manuscript_query = ' AND siglum_slug:"{0}"'.format(manuscript)
-
-        solrconn = solr.SolrConnection(settings.SOLR_SERVER)
-
-        # Normalize case and whitespace
-        query = " ".join(elem for elem in query.lower().split())
-
-        if qtype == "neumes":
-            query_stmt = "neumes:{0}".format(
-                # query
-                query.replace(" ", "_")
-            )
-        elif qtype == "pnames" or qtype == "pnames-invariant":
-            if not search_utils.valid_pitch_sequence(query):
-                raise NotationException(
-                    "The query you provided is not a valid pitch sequence"
-                )
-            real_query = (
-                query
-                if qtype == "pnames"
-                else " OR ".join(search_utils.get_transpositions(query))
-            )
-            query_stmt = "pnames:({0})".format(real_query)
-        elif qtype == "contour":
-            query_stmt = "contour:{0}".format(query)
-        elif qtype == "text":
-            query_stmt = "text:{0}".format(query)
-        elif qtype == "intervals":
-            query_stmt = "intervals:{0}".format(query.replace(" ", "_"))
-        elif qtype == "incipit":
-            query_stmt = "incipit:{0}*".format(query)
-        else:
-            raise NotationException("Invalid query type provided")
-
-        if qtype == "pnames-invariant":
-            print(query_stmt + manuscript_query)
-            response = solrconn.query(
-                query_stmt + manuscript_query,
-                score=False,
-                sort="folio asc",
-                q_op="OR",
-                rows=1000000,
+        manuscript_param = request.GET.get("manuscript", None)
+        rows_param = request.GET.get("rows", "100")
+        start_param = request.GET.get("start", "0")
+
+        # Do some parameter validation and cast to appropriate types
+        if not q or not stype or not manuscript_param:
+            raise NotationSearchException("Missing required parameters.")
+        # Use str.isdecimal() rather than str.isdigit(): isdigit() also accepts
+        # characters such as "²" that int() cannot parse, which would turn a
+        # bad request into an unhandled ValueError (HTTP 500).
+        if not rows_param.isdecimal() or not start_param.isdecimal():
+            raise NotationSearchException("'Rows' and 'Start' must be integers.")
+        if not manuscript_param.isdecimal():
+            raise NotationSearchException("Manuscript ID must be digits.")
+        rows = int(rows_param)
+        start = int(start_param)
+        manuscript_id = int(manuscript_param)
+
+        query_str = self.create_query_string(q, stype)
+        results, num_found = self.do_query(manuscript_id, query_str, rows, start)
+
+        return Response({"numFound": num_found, "results": results})
+
+    def create_query_string(self, q: str, q_type: str) -> str:
+        """
+        Format the query and query type into a string for solr.
+
+        The query is lowercased and split on whitespace, and its elements
+        are joined with underscores. For a "pitch_names_invariant" query,
+        every transposition of the query is searched in the "pitch_names"
+        field, chained together with ORs.
+
+        Raises a NotationSearchException if validate_query rejects the query.
+        """
+        normalized_q_elems = q.lower().split()
+        q_valid = validate_query(normalized_q_elems, q_type)
+        if not q_valid:
+            raise NotationSearchException("Invalid query.")
+        if q_type == "pitch_names_invariant":
+            transpositions = get_transpositions(normalized_q_elems)
+            # Create a query string for the transpositions:
+            # e.g. if transpositions are [["a", "b", "c"], ["b", "c", "d"], etc.]
+            # then the query string will be:
+            # "(a_b_c OR b_c_d OR etc.)"
+            q_str = " OR ".join(
+                "_".join(transposition) for transposition in transpositions
             )
+            q_str = f"({q_str})"
+            q_type = "pitch_names"
         else:
-            print(query_stmt + manuscript_query)
-            response = solrconn.query(
-                query_stmt + manuscript_query,
-                score=False,
-                sort="folio asc",
-                rows=1000000,
-            )
-
-        results = []
-
-        # get only the longest ngram in the results, for results which are associated with
-        # a pitch sequence
-        if qtype == "neumes":
-            if manuscript == "ch-sgs-390":
-                pass
-            else:
-                notegrams_num = search_utils.get_neumes_length(query)
-                response = [
-                    r
-                    for r in response
-                    if not r.get("pnames") or len(r["pnames"]) == notegrams_num
-                ]
-
-        box_sort_key = itemgetter("p", "y")
-
-        for d in response:
-            image_uri = d["image_uri"]
-            folio = d["folio"]
-            locations = json.loads(d["location"].replace("'", '"'))
-
-            if isinstance(locations, dict):
-                box_w = locations["width"]
-                box_h = locations["height"]
-                box_x = locations["ulx"]
-                box_y = locations["uly"]
-                boxes = [
-                    {
-                        "p": image_uri,
-                        "f": folio,
-                        "w": box_w,
-                        "h": box_h,
-                        "x": box_x,
-                        "y": box_y,
-                    }
-                ]
-            else:
-                boxes = []
-
-                for location in locations:
-                    box_w = location["width"]
-                    box_h = location["height"]
-                    box_x = location["ulx"]
-                    box_y = location["uly"]
-                    boxes.append(
-                        {
-                            "p": image_uri,
-                            "f": folio,
-                            "w": box_w,
-                            "h": box_h,
-                            "x": box_x,
-                            "y": box_y,
-                        }
-                    )
-
-            boxes.sort(key=box_sort_key)
-
-            results.append(
+            q_str = "_".join(normalized_q_elems)
+        return f"{q_type}:{q_str}"
+
+    def create_boxes(
+        self, locations: list[dict[str, int]], image_uri: str, folio: str
+    ) -> list[dict[str, Union[int, str]]]:
+        """
+        Build one bounding box per location.
+
+        Each box carries the image URI ("p"), the folio ("f"), and the
+        location's "width", "height", "ulx" and "uly" values under the
+        keys "w", "h", "x" and "y".
+        """
+        boxes: list[dict[str, Union[int, str]]] = []
+        for location in locations:
+            boxes.append(
                 {
-                    "boxes": boxes,
-                    "contour": get_value(d, "contour", list),
-                    "intervals": get_value(d, "intervals", lambda i: i.split("_")),
-                    "neumes": get_value(d, "neumes", lambda i: i.split("_")),
-                    "pnames": get_value(d, "pnames", list),
-                    "semitones": get_value(
-                        d,
-                        "semitones",
-                        lambda tones: [int(s) for s in tones.split("_")],
-                    ),
+                    "p": image_uri,
+                    "f": folio,
+                    "w": location["width"],
+                    "h": location["height"],
+                    "x": location["ulx"],
+                    "y": location["uly"],
                 }
             )
-
-        results.sort(key=lambda result: [box_sort_key(box) for box in result["boxes"]])
-
-        return results
-
-
-def get_value(d, key, transform):
-    try:
-        value = d[key]
-    except KeyError:
-        return None
-
-    return transform(value)
+        return boxes
+
+    def do_query(
+        self, manuscript_id: int, q_str: str, rows: int, start: int
+    ) -> tuple[list[NotationSearchResultItem], int]:
+        """
+        Query solr for the OMR ngrams of a manuscript that match q_str.
+
+        Returns a tuple of the requested page of results (at most `rows`
+        items, starting at offset `start`, sorted by folio) and the total
+        number of matches reported by solr.
+
+        Raises requests.HTTPError if solr responds with an error status.
+        """
+        # Let requests URL-encode the parameters so that characters such as
+        # "&", "#" or "+" in the query cannot change the structure of the
+        # request sent to solr.
+        solr_params: dict[str, Union[int, str]] = {
+            "q": "*:*",
+            # Restrict results to OMR ngrams of the requested manuscript
+            "fq": f"type:omr_ngram AND manuscript_id:{manuscript_id} AND {q_str}",
+            "fl": ",".join(RETURNED_FIELDS),
+            "sort": "folio asc",
+            "rows": rows,
+            "start": start,
+        }
+        response = requests.get(
+            f"{settings.SOLR_SERVER}/select", params=solr_params, timeout=10
+        )
+        response.raise_for_status()
+        response_body = response.json()["response"]
+        request_results: list[SolrQueryResultItem] = response_body["docs"]
+        num_found: int = response_body["numFound"]
+        results: list[NotationSearchResultItem] = []
+        for d in request_results:
+            boxes = self.create_boxes(d["location_json"], d["image_uri"], d["folio"])
+            result: NotationSearchResultItem = {
+                "boxes": boxes,
+                "contour": d["contour"].split("_"),
+                "semitones": d["semitone_intervals"].split("_"),
+                "pnames": d["pitch_names"].split("_"),
+            }
+            # "neumes" is only included when the document has neume names
+            neume_names: Optional[str] = d.get("neume_names")
+            if neume_names:
+                result["neumes"] = neume_names.split("_")
+            results.append(result)
+
+        return results, num_found