From e1d42eb6b44b4ab9a147cc9906b4cbb4a1e192ce Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Fri, 18 Oct 2024 12:16:00 -0700 Subject: [PATCH] Add a cell for displaying loaded validation results. PiperOrigin-RevId: 687384608 --- analysis.ipynb | 29 ++++++++++++++++++++++ chirp/inference/call_density.py | 24 ++++++++++++++++++ chirp/inference/search/search.py | 3 ++- chirp/inference/tests/call_density_test.py | 28 +++++++++++++++++++++ 4 files changed, 83 insertions(+), 1 deletion(-) diff --git a/analysis.ipynb b/analysis.ipynb index 316f4a98..cddc15e1 100644 --- a/analysis.ipynb +++ b/analysis.ipynb @@ -312,6 +312,35 @@ "roc_auc_estimate = call_density.estimate_roc_auc(validation_examples)\n", "print(f'Estimated ROC-AUC : {roc_auc_estimate:5.4f}')" ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "id": "0ii9H72iQknv" + }, + "outputs": [], + "source": [ + "#@title Display Logged Validation Examples. { vertical-output: true }\n", + "\n", + "validation_results = search.TopKSearchResults(top_k=len(validation_examples))\n", + "for v in validation_examples:\n", + " validation_results.update(v.to_search_result(\n", + " target_class, project_state.embedding_model.sample_rate))\n", + "\n", + "samples_per_page = 40 #@param\n", + "page_state = display.PageState(\n", + " np.ceil(combined_results.top_k / samples_per_page))\n", + "\n", + "display.display_paged_results(\n", + " validation_results,\n", + " page_state, samples_per_page,\n", + " project_state=project_state,\n", + " embedding_sample_rate=project_state.embedding_model.sample_rate,\n", + " exclusive_labels=True,\n", + " checkbox_labels=[target_class, f'not {target_class}', 'unsure'],\n", + ")" + ] } ], "metadata": { diff --git a/chirp/inference/call_density.py b/chirp/inference/call_density.py index dac25a38..a40531d1 100644 --- a/chirp/inference/call_density.py +++ b/chirp/inference/call_density.py @@ -21,6 +21,7 @@ from chirp.inference.search import search from etils import epath +import ipywidgets import numpy as np import pandas as pd import scipy @@ -59,6 +60,29 @@ def to_row(self): self.bin_weight, ] + def to_search_result(self, target_class: str): + """Convert to a search result for display only.""" + result = search.SearchResult( + filename=self.filename, + timestamp_offset=self.timestamp_offset, + score=self.score, + sort_score=np.random.uniform(), + embedding=np.zeros(shape=(0,), dtype=np.float32), + ) + b = ipywidgets.RadioButtons( + options=[target_class, f'not {target_class}', 'unsure'] + ) + if self.is_pos == 1: + b.value = target_class + elif self.is_pos == -1: + b.value = f'not {target_class}' + elif self.is_pos == 0: + b.value = 'unsure' + else: + raise ValueError(f'unexpected value ({self.is_pos})') + result.label_widgets = [b] + return result + @classmethod def from_search_result( cls, diff --git a/chirp/inference/search/search.py b/chirp/inference/search/search.py index 2e6b0c35..cdcc2b5b 100644 --- a/chirp/inference/search/search.py +++ b/chirp/inference/search/search.py @@ -42,7 +42,8 @@ class SearchResult: # Source file contianing corresponding audio. filename: str # Time offset for audio. - timestamp_offset: int + # TODO(tomdenton): Convert to float only; this is measured in seconds. + timestamp_offset: int | float # The following are populated as needed. audio: np.ndarray | None = None diff --git a/chirp/inference/tests/call_density_test.py b/chirp/inference/tests/call_density_test.py index ed3d5dd0..0cae23ee 100644 --- a/chirp/inference/tests/call_density_test.py +++ b/chirp/inference/tests/call_density_test.py @@ -19,9 +19,12 @@ import shutil import string import tempfile +from unittest import mock from chirp.inference import call_density from etils import epath +import IPython +import ipywidgets import numpy as np from sklearn import metrics @@ -31,6 +34,18 @@ class CallDensityTest(absltest.TestCase): def setUp(self): + # Without this, unit tests using Ipywidgets will fail with 'Comms cannot be + # opened without a kernel and a comm_manager attached to that kernel'. This + # mocks out the comms. This is a little fragile because it sets a private + # attribute and may break for future Ipywidget library upgrades. + setattr( + ipywidgets.Widget, + '_comm_default', + lambda self: mock.MagicMock(spec=IPython.kernel.comm.Comm), + ) + + super().setUp() + super().setUp() self.tempdir = tempfile.mkdtemp() @@ -156,6 +171,19 @@ def test_write_read_log(self): got_examples = call_density.load_validation_log(log_filepath) self.assertLen(got_examples, len(examples)) + with self.subTest('to_result'): + r = got_examples[0].to_search_result('someclass') + self.assertEqual(r.filename, got_examples[0].filename) + self.assertEqual(r.timestamp_offset, got_examples[0].timestamp_offset) + self.assertEqual(r.score, got_examples[0].score) + if examples[0].is_pos == 1: + self.assertEqual(r.label_widgets[0].value, 'someclass') + elif examples[0].is_pos == -1: + self.assertEqual(r.label_widgets[0].value, 'not someclass') + elif examples[0].is_pos == 0: + self.assertEqual(r.label_widgets[0].value, 'unsure') + else: + raise ValueError(f'unexpected value ({examples[0].is_pos})') if __name__ == '__main__': absltest.main()