Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a cell for displaying loaded validation results. #697

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,35 @@
"roc_auc_estimate = call_density.estimate_roc_auc(validation_examples)\n",
"print(f'Estimated ROC-AUC : {roc_auc_estimate:5.4f}')"
]
},
{
 "cell_type": "code",
 "execution_count": 0,
 "metadata": {
  "id": "0ii9H72iQknv"
 },
 "outputs": [],
 "source": [
  "#@title Display Logged Validation Examples. { vertical-output: true }\n",
  "\n",
  "# Rebuild displayable search results from the logged validation examples.\n",
  "# to_search_result takes only the target class name.\n",
  "validation_results = search.TopKSearchResults(top_k=len(validation_examples))\n",
  "for v in validation_examples:\n",
  "  validation_results.update(v.to_search_result(target_class))\n",
  "\n",
  "samples_per_page = 40 #@param\n",
  "# Size the pager from the results shown in this cell.\n",
  "page_state = display.PageState(\n",
  "    np.ceil(validation_results.top_k / samples_per_page))\n",
  "\n",
  "display.display_paged_results(\n",
  "    validation_results,\n",
  "    page_state, samples_per_page,\n",
  "    project_state=project_state,\n",
  "    embedding_sample_rate=project_state.embedding_model.sample_rate,\n",
  "    exclusive_labels=True,\n",
  "    checkbox_labels=[target_class, f'not {target_class}', 'unsure'],\n",
  ")"
 ]
}
],
"metadata": {
Expand Down
24 changes: 24 additions & 0 deletions chirp/inference/call_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from chirp.inference.search import search
from etils import epath
import ipywidgets
import numpy as np
import pandas as pd
import scipy
Expand Down Expand Up @@ -59,6 +60,29 @@ def to_row(self):
self.bin_weight,
]

def to_search_result(self, target_class: str):
  """Build a display-only `search.SearchResult` from this example.

  The result copies this example's filename, timestamp offset and score,
  gets a uniformly random `sort_score`, and carries an empty float32
  embedding. A single radio-button widget, pre-selected from `self.is_pos`,
  is attached as the result's `label_widgets`.

  Args:
    target_class: Class name used for the positive label; the negative label
      is rendered as 'not <target_class>'.

  Returns:
    The populated search result.

  Raises:
    ValueError: If `self.is_pos` is not one of 1, -1 or 0.
  """
  negative_label = f'not {target_class}'
  search_result = search.SearchResult(
      filename=self.filename,
      timestamp_offset=self.timestamp_offset,
      score=self.score,
      sort_score=np.random.uniform(),
      embedding=np.zeros(shape=(0,), dtype=np.float32),
  )
  radio = ipywidgets.RadioButtons(
      options=[target_class, negative_label, 'unsure']
  )
  # Map each recognised is_pos code to the label that should be pre-selected.
  selections = ((1, target_class), (-1, negative_label), (0, 'unsure'))
  for code, label in selections:
    if self.is_pos == code:
      radio.value = label
      break
  else:
    raise ValueError(f'unexpected value ({self.is_pos})')
  search_result.label_widgets = [radio]
  return search_result

@classmethod
def from_search_result(
cls,
Expand Down
3 changes: 2 additions & 1 deletion chirp/inference/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class SearchResult:
# Source file containing corresponding audio.
filename: str
# Time offset for audio.
timestamp_offset: int
# TODO(tomdenton): Convert to float only; this is measured in seconds.
timestamp_offset: int | float

# The following are populated as needed.
audio: np.ndarray | None = None
Expand Down
28 changes: 28 additions & 0 deletions chirp/inference/tests/call_density_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
import shutil
import string
import tempfile
from unittest import mock

from chirp.inference import call_density
from etils import epath
import IPython
import ipywidgets
import numpy as np
from sklearn import metrics

Expand All @@ -31,6 +34,18 @@
class CallDensityTest(absltest.TestCase):

def setUp(self):
  """Create a temp directory and stub out ipywidgets comms for each test."""
  super().setUp()

  # Without this, unit tests using Ipywidgets will fail with 'Comms cannot be
  # opened without a kernel and a comm_manager attached to that kernel'. This
  # mocks out the comms. This is a little fragile because it sets a private
  # attribute and may break for future Ipywidget library upgrades.
  #
  # The patch is undone after each test so the stub does not leak into other
  # tests in the same process. `create=True` keeps this working whether or
  # not `Widget` already defines `_comm_default`.
  comm_patcher = mock.patch.object(
      ipywidgets.Widget,
      '_comm_default',
      lambda self: mock.MagicMock(spec=IPython.kernel.comm.Comm),
      create=True,
  )
  comm_patcher.start()
  self.addCleanup(comm_patcher.stop)

  self.tempdir = tempfile.mkdtemp()

Expand Down Expand Up @@ -156,6 +171,19 @@ def test_write_read_log(self):
got_examples = call_density.load_validation_log(log_filepath)
self.assertLen(got_examples, len(examples))

with self.subTest('to_result'):
r = got_examples[0].to_search_result('someclass')
self.assertEqual(r.filename, got_examples[0].filename)
self.assertEqual(r.timestamp_offset, got_examples[0].timestamp_offset)
self.assertEqual(r.score, got_examples[0].score)
if examples[0].is_pos == 1:
self.assertEqual(r.label_widgets[0].value, 'someclass')
elif examples[0].is_pos == -1:
self.assertEqual(r.label_widgets[0].value, 'not someclass')
elif examples[0].is_pos == 0:
self.assertEqual(r.label_widgets[0].value, 'unsure')
else:
raise ValueError(f'unexpected value ({examples[0].is_pos})')

if __name__ == '__main__':
absltest.main()
Loading