Skip to content

Commit

Permalink
Rewrite search_notation view and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dchiller committed May 27, 2024
1 parent 0def7ce commit 2856bba
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 154 deletions.
131 changes: 131 additions & 0 deletions app/public/cantusdata/test/core/views/test_search_notation_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from rest_framework.test import APITransactionTestCase
from django.core.management import call_command
from django.urls import reverse

from cantusdata.views.search_notation import SearchNotationView, NotationSearchException

TEST_MEI_FILES_PATH = "cantusdata/test/core/helpers/mei_processing/test_mei_files"


class TestSearchNotationView(APITransactionTestCase):
search_notation_view = SearchNotationView()

def setUp(self) -> None:
call_command(
"index_manuscript_mei",
"123723",
"--min-ngram",
"1",
"--max-ngram",
"5",
"--mei-dir",
TEST_MEI_FILES_PATH,
)

def test_create_query_string(self) -> None:
with self.subTest("Test invalid query"):
with self.assertRaises(NotationSearchException):
self.search_notation_view.create_query_string(
"a_b_q", q_type="pitch_names"
)
with self.subTest("Test valid query"):
query = "u d U r "
query_type = "contour"
query_string = self.search_notation_view.create_query_string(
query, query_type
)
expected_query_string = "contour:u_d_u_r"
self.assertEqual(query_string, expected_query_string)
# We add a separate subtest for a "pitch_names_invariant" query since it
# has a slightly different logic (we get transpositions and chain them
# together with ORs).
with self.subTest("Test pitch_names_invariant query"):
query = "c d e"
query_type = "pitch_names_invariant"
query_string = self.search_notation_view.create_query_string(
query, query_type
)
expected_query_string = "pitch_names:(c_d_e OR d_e_f OR e_f_g OR f_g_a OR g_a_b OR a_b_c OR b_c_d)"
self.assertEqual(query_string, expected_query_string)

def test_do_query(self) -> None:
with self.subTest("Test fields returned"):
# Test that, in general, the fields returned are as expected
expected_results_fields = [
"boxes",
"contour",
"semitones",
"pnames",
]
results, _ = self.search_notation_view.do_query(
123723, "contour:u d u", 100, 0
)
results_fields = list(results[0].keys())
self.assertTrue(set(expected_results_fields).issubset(results_fields))
# Test a case where we know that neume names are returned and ensure
# that the "neumes" field is present in the results
results_neume_names, _ = self.search_notation_view.do_query(
123723, "neume_names:punctum", 100, 0
)
results_neume_names_fields = list(results_neume_names[0].keys())
expected_results_fields.append("neumes")
self.assertTrue(
set(expected_results_fields).issubset(results_neume_names_fields)
)
with self.subTest("Test rows and start parameters"):
results_rows_100_start_0, _ = self.search_notation_view.do_query(
123723, "neume_names:punctum", 100, 0
)
self.assertEqual(len(results_rows_100_start_0), 100)
results_rows_10_start_0, _ = self.search_notation_view.do_query(
123723, "neume_names:punctum", 10, 0
)
self.assertEqual(len(results_rows_10_start_0), 10)
self.assertEqual(results_rows_100_start_0[:10], results_rows_10_start_0)
results_rows_10_start_10, _ = self.search_notation_view.do_query(
123723, "neume_names:punctum", 10, 10
)
self.assertEqual(len(results_rows_10_start_10), 10)
self.assertEqual(results_rows_100_start_0[10:20], results_rows_10_start_10)
with self.subTest("Test manuscript_id parameter"):
_, num_found_123723 = self.search_notation_view.do_query(
123723, "neume_names:punctum", 100, 0
)
_, num_found_123724 = self.search_notation_view.do_query(
123724, "neume_names:punctum", 100, 0
)
self.assertGreater(num_found_123723, 0)
self.assertEqual(num_found_123724, 0)

def test_get(self) -> None:
url = reverse("search-notation-view")
with self.subTest("Test missing required parameters"):
params_no_manuscript: dict[str, str | int] = {
"q": "u d u",
"type": "contour",
}
response_no_manuscript = self.client.get(url, params_no_manuscript)
self.assertEqual(response_no_manuscript.status_code, 400)
params_no_type: dict[str, str | int] = {"q": "u d u", "manuscript": 123723}
response_no_type = self.client.get(url, params_no_type)
self.assertEqual(response_no_type.status_code, 400)
params_no_q: dict[str, str | int] = {
"type": "contour",
"manuscript": 123723,
}
response_no_q = self.client.get(url, params_no_q)
self.assertEqual(response_no_q.status_code, 400)
with self.subTest("Test response"):
params: dict[str, str | int] = {
"q": "u d u",
"type": "contour",
"manuscript": 123723,
}
response = self.client.get(url, params)
self.assertEqual(response.status_code, 200)
response_data = response.json()
self.assertIn("results", response_data)
self.assertIn("numFound", response_data)

def tearDown(self) -> None:
call_command("index_manuscript_mei", "123723", "--flush-index")
Loading

0 comments on commit 2856bba

Please sign in to comment.