diff --git a/chirp/inference/interface.py b/chirp/inference/interface.py
index ec1d96d8..3bb4049c 100644
--- a/chirp/inference/interface.py
+++ b/chirp/inference/interface.py
@@ -119,25 +119,6 @@ def batch_embed(self, audio_batch: np.ndarray) -> InferenceOutputs:
     """
     raise NotImplementedError
 
-  def convert_logits(
-      self,
-      logits: np.ndarray,
-      source_class_list: namespace.ClassList,
-      target_class_list: namespace.ClassList | None,
-  ) -> np.ndarray:
-    """Convert model logits to logits for a different class list."""
-    if target_class_list is None:
-      return logits
-    sp_matrix, sp_mask = source_class_list.get_class_map_matrix(
-        target_class_list
-    )
-    # When we convert from ClassList A (used for training) to ClassList B
-    # (for inference output) there may be labels in B which don't appear in A.
-    # The `sp_mask` tells us which labels appear in both A and B. We set the
-    # logit for the new labels to NULL_LOGIT, which corresponds to a probability
-    # very close to zero.
-    return logits @ sp_matrix + NULL_LOGIT * (1 - sp_mask)
-
   def frame_audio(
       self,
       audio_array: np.ndarray,
diff --git a/chirp/inference/models.py b/chirp/inference/models.py
index 58ea12c9..0dda4657 100644
--- a/chirp/inference/models.py
+++ b/chirp/inference/models.py
@@ -248,7 +248,6 @@ class TaxonomyModelTF(interface.EmbeddingModel):
     model: Loaded TF SavedModel.
     class_list: Loaded class_list for the model's output logits.
     batchable: Whether the model supports batched input.
-    target_class_list: If provided, restricts logits to this ClassList.
     target_peak: Peak normalization value.
   """
 
@@ -258,7 +257,6 @@ class TaxonomyModelTF(interface.EmbeddingModel):
   model: Any  # TF SavedModel
   class_list: namespace.ClassList
   batchable: bool
-  target_class_list: namespace.ClassList | None = None
   target_peak: float | None = 0.25
   tfhub_version: int | None = None
 
@@ -340,10 +338,6 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
       all_embeddings = np.concatenate([all_embeddings, embeddings], axis=0)
     all_embeddings = all_embeddings[:, np.newaxis, :]
 
-    all_logits = self.convert_logits(
-        all_logits, self.class_list, self.target_class_list
-    )
-
     return interface.InferenceOutputs(
         all_embeddings, {'label': all_logits}, None
     )
@@ -362,9 +356,6 @@ def batch_embed(
     rebatched_audio = framed_audio.reshape([-1, framed_audio.shape[-1]])
     logits, embeddings = self.model.infer_tf(rebatched_audio)
-    logits = self.convert_logits(
-        logits, self.class_list, self.target_class_list
-    )
     logits = np.reshape(logits, framed_audio.shape[:2] + (logits.shape[-1],))
     embeddings = np.reshape(
         embeddings, framed_audio.shape[:2] + (embeddings.shape[-1],)
     )
@@ -382,7 +373,6 @@ class SeparatorModelTF(interface.EmbeddingModel):
     frame_size: Audio frame size for separation model.
     model: Loaded TF SavedModel.
     class_list: Loaded class_list for the model's output logits.
-    target_class_list: If provided, restricts logits to this ClassList.
     windows_size_s: Window size for framing audio in samples. The audio will
       be chunked into frames of size window_size_s, which may help avoid
       memory blowouts. If None, will simply treat all audio as a single frame.
@@ -392,7 +382,6 @@ class SeparatorModelTF(interface.EmbeddingModel):
   frame_size: int
   model: Any
   class_list: namespace.ClassList
-  target_class_list: namespace.ClassList | None = None
   window_size_s: float | None = None
 
   @classmethod
@@ -434,9 +423,6 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
     # Recombine batch and time dimensions.
     sep_audio = np.reshape(sep_audio, [-1, sep_audio.shape[-1]])
     all_logits = np.reshape(all_logits, [-1, all_logits.shape[-1]])
-    all_logits = self.convert_logits(
-        all_logits, self.class_list, self.target_class_list
-    )
     all_embeddings = np.reshape(all_embeddings, [-1, all_embeddings.shape[-1]])
     return interface.InferenceOutputs(
         all_embeddings, {'label': all_logits}, sep_audio
     )
@@ -459,7 +445,6 @@ class BirdNet(interface.EmbeddingModel):
     hop_size_s: Hop size for inference.
     num_tflite_threads: Number of threads to use with TFLite model.
     class_list_name: Name of the BirdNet class list.
-    target_class_list: If provided, restricts logits to this ClassList.
   """
 
   model_path: str
@@ -470,7 +455,6 @@ class BirdNet(interface.EmbeddingModel):
   hop_size_s: float = 3.0
   num_tflite_threads: int = 16
   class_list_name: str = 'birdnet_v2_1'
-  target_class_list: namespace.ClassList | None = None
 
   @classmethod
   def from_config(cls, config: config_dict.ConfigDict) -> 'BirdNet':
@@ -505,9 +489,6 @@ def embed_saved_model(
     for window in audio_array[1:]:
       logits = self.model(window[np.newaxis, :])
       all_logits = np.concatenate([all_logits, logits], axis=0)
-    all_logits = self.convert_logits(
-        all_logits, self.class_list, self.target_class_list
-    )
     return interface.InferenceOutputs(
         None, {self.class_list_name: all_logits}, None
     )
@@ -529,9 +510,6 @@ def embed_tflite(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
     # Create [Batch, 1, Features]
     embeddings = np.array(embeddings)
     logits = np.array(logits)
-    logits = self.convert_logits(
-        logits, self.class_list, self.target_class_list
-    )
     return interface.InferenceOutputs(
         embeddings, {self.class_list_name: logits}, None
     )
@@ -686,7 +664,6 @@ class PlaceholderModel(interface.EmbeddingModel):
   make_logits: bool = True
   make_separated_audio: bool = True
   do_frame_audio: bool = False
-  target_class_list: namespace.ClassList | None = None
   window_size_s: float = 1.0
   hop_size_s: float = 1.0
 
@@ -720,9 +697,6 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
             [time_size, len(self.class_list.classes)], np.float32
         ),
     }
-    outputs['logits']['label'] = self.convert_logits(
-        outputs['logits']['label'], self.class_list, self.target_class_list
-    )
     if self.make_separated_audio:
       outputs['separated_audio'] = np.zeros(
           [2, audio_array.shape[-1]], np.float32
diff --git a/chirp/tests/inference_test.py b/chirp/tests/inference_test.py
index 862617f1..66637d07 100644
--- a/chirp/tests/inference_test.py
+++ b/chirp/tests/inference_test.py
@@ -407,15 +407,12 @@ def test_sep_embed_wrapper(self):
         make_logits=False,
         make_separated_audio=True,
     )
-    db = namespace_db.load_db()
-    target_class_list = db.class_lists['high_sierras']
     embeddor = models.PlaceholderModel(
         sample_rate=22050,
         make_embeddings=True,
         make_logits=True,
         make_separated_audio=False,
-        target_class_list=target_class_list,
     )
     fake_config = config_dict.ConfigDict()
     sep_embed = models.SeparateEmbedModel(
@@ -438,7 +435,7 @@ def test_sep_embed_wrapper(self):
     )
     # The Sep+Embed model takes the max logits over the channel dimension.
     self.assertSequenceEqual(
-        outputs.logits['label'].shape, [5, len(target_class_list.classes)]
+        outputs.logits['label'].shape, [5, len(embeddor.class_list.classes)]
     )
 
   def test_pooled_embeddings(self):
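
Note for downstream users: this change removes `convert_logits` and the `target_class_list` option from all embedding models, so logits are now always returned in each model's native `class_list`. Callers that relied on `target_class_list` can apply the same conversion themselves after inference. Below is a minimal sketch lifted from the deleted method; the concrete `NULL_LOGIT` value shown here is an assumption, so use the constant defined in `chirp/inference/interface.py`:

    import numpy as np

    from chirp.taxonomy import namespace

    # Assumed stand-in for the NULL_LOGIT constant from interface.py: a logit
    # this low corresponds to a probability very close to zero.
    NULL_LOGIT = -20.0

    def convert_logits(
        logits: np.ndarray,
        source_class_list: namespace.ClassList,
        target_class_list: namespace.ClassList,
    ) -> np.ndarray:
      """Maps logits from the source class list onto the target class list."""
      # sp_matrix is a [num_source, num_target] 0/1 mapping matrix; sp_mask
      # marks the target classes that also appear in the source class list.
      sp_matrix, sp_mask = source_class_list.get_class_map_matrix(
          target_class_list
      )
      # Target-only classes get NULL_LOGIT, i.e. probability ~0.
      return logits @ sp_matrix + NULL_LOGIT * (1 - sp_mask)

For example, logits from a TaxonomyModelTF can be restricted to a regional list with convert_logits(outputs.logits['label'], model.class_list, db.class_lists['high_sierras']), where db comes from namespace_db.load_db() as in the old test code.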