Remove target_class_list conversion from EmbeddingModel. It is slow, and can better be handled by the caller.

PiperOrigin-RevId: 580298450
sdenton4 authored and copybara-github committed Nov 7, 2023
1 parent 661b894 commit 4c65116
Showing 3 changed files with 1 addition and 49 deletions.
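
Note: callers that still need logits against a different ClassList can apply the same mapping themselves after inference. Below is a minimal standalone sketch of the removed logic; the namespace import path and the NULL_LOGIT value are assumptions for illustration, not part of this commit:

import numpy as np

from chirp.taxonomy import namespace  # assumed import path for ClassList

NULL_LOGIT = -20.0  # assumption: stand-in for the NULL_LOGIT constant in interface.py


def convert_logits(
    logits: np.ndarray,
    source_class_list: namespace.ClassList,
    target_class_list: namespace.ClassList,
) -> np.ndarray:
  """Maps logits from the model's class list onto a target class list."""
  # sp_matrix projects source-class logits onto the target classes; sp_mask
  # marks which target classes also appear in the source class list.
  sp_matrix, sp_mask = source_class_list.get_class_map_matrix(target_class_list)
  # Target classes absent from the source get NULL_LOGIT, i.e. a probability
  # very close to zero.
  return logits @ sp_matrix + NULL_LOGIT * (1 - sp_mask)


# Hypothetical usage, converting after inference instead of inside the model:
# outputs = model.embed(audio_array)
# target_logits = convert_logits(
#     outputs.logits['label'], model.class_list, target_class_list)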
19 changes: 0 additions & 19 deletions chirp/inference/interface.py
@@ -119,25 +119,6 @@ def batch_embed(self, audio_batch: np.ndarray) -> InferenceOutputs:
    """
    raise NotImplementedError

-  def convert_logits(
-      self,
-      logits: np.ndarray,
-      source_class_list: namespace.ClassList,
-      target_class_list: namespace.ClassList | None,
-  ) -> np.ndarray:
-    """Convert model logits to logits for a different class list."""
-    if target_class_list is None:
-      return logits
-    sp_matrix, sp_mask = source_class_list.get_class_map_matrix(
-        target_class_list
-    )
-    # When we convert from ClassList A (used for training) to ClassList B
-    # (for inference output) there may be labels in B which don't appear in A.
-    # The `sp_mask` tells us which labels appear in both A and B. We set the
-    # logit for the new labels to NULL_LOGIT, which corresponds to a probability
-    # very close to zero.
-    return logits @ sp_matrix + NULL_LOGIT * (1 - sp_mask)
-
  def frame_audio(
      self,
      audio_array: np.ndarray,
26 changes: 0 additions & 26 deletions chirp/inference/models.py
@@ -248,7 +248,6 @@ class TaxonomyModelTF(interface.EmbeddingModel):
    model: Loaded TF SavedModel.
    class_list: Loaded class_list for the model's output logits.
    batchable: Whether the model supports batched input.
-    target_class_list: If provided, restricts logits to this ClassList.
    target_peak: Peak normalization value.
  """

@@ -258,7 +257,6 @@ class TaxonomyModelTF(interface.EmbeddingModel):
  model: Any  # TF SavedModel
  class_list: namespace.ClassList
  batchable: bool
-  target_class_list: namespace.ClassList | None = None
  target_peak: float | None = 0.25
  tfhub_version: int | None = None

@@ -340,10 +338,6 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
      all_embeddings = np.concatenate([all_embeddings, embeddings], axis=0)
    all_embeddings = all_embeddings[:, np.newaxis, :]

-    all_logits = self.convert_logits(
-        all_logits, self.class_list, self.target_class_list
-    )
-
    return interface.InferenceOutputs(
        all_embeddings, {'label': all_logits}, None
    )
@@ -362,9 +356,6 @@ def batch_embed(

    rebatched_audio = framed_audio.reshape([-1, framed_audio.shape[-1]])
    logits, embeddings = self.model.infer_tf(rebatched_audio)
-    logits = self.convert_logits(
-        logits, self.class_list, self.target_class_list
-    )
    logits = np.reshape(logits, framed_audio.shape[:2] + (logits.shape[-1],))
    embeddings = np.reshape(
        embeddings, framed_audio.shape[:2] + (embeddings.shape[-1],)
@@ -382,7 +373,6 @@ class SeparatorModelTF(interface.EmbeddingModel):
    frame_size: Audio frame size for separation model.
    model: Loaded TF SavedModel.
    class_list: Loaded class_list for the model's output logits.
-    target_class_list: If provided, restricts logits to this ClassList.
    windows_size_s: Window size for framing audio in samples. The audio will be
      chunked into frames of size window_size_s, which may help avoid memory
      blowouts. If None, will simply treat all audio as a single frame.
@@ -392,7 +382,6 @@
  frame_size: int
  model: Any
  class_list: namespace.ClassList
-  target_class_list: namespace.ClassList | None = None
  window_size_s: float | None = None

  @classmethod
@@ -434,9 +423,6 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
    # Recombine batch and time dimensions.
    sep_audio = np.reshape(sep_audio, [-1, sep_audio.shape[-1]])
    all_logits = np.reshape(all_logits, [-1, all_logits.shape[-1]])
-    all_logits = self.convert_logits(
-        all_logits, self.class_list, self.target_class_list
-    )
    all_embeddings = np.reshape(all_embeddings, [-1, all_embeddings.shape[-1]])
    return interface.InferenceOutputs(
        all_embeddings, {'label': all_logits}, sep_audio
@@ -459,7 +445,6 @@ class BirdNet(interface.EmbeddingModel):
    hop_size_s: Hop size for inference.
    num_tflite_threads: Number of threads to use with TFLite model.
    class_list_name: Name of the BirdNet class list.
-    target_class_list: If provided, restricts logits to this ClassList.
  """

  model_path: str
@@ -470,7 +455,6 @@
  hop_size_s: float = 3.0
  num_tflite_threads: int = 16
  class_list_name: str = 'birdnet_v2_1'
-  target_class_list: namespace.ClassList | None = None

  @classmethod
  def from_config(cls, config: config_dict.ConfigDict) -> 'BirdNet':
@@ -505,9 +489,6 @@ def embed_saved_model(
    for window in audio_array[1:]:
      logits = self.model(window[np.newaxis, :])
      all_logits = np.concatenate([all_logits, logits], axis=0)
-    all_logits = self.convert_logits(
-        all_logits, self.class_list, self.target_class_list
-    )
    return interface.InferenceOutputs(
        None, {self.class_list_name: all_logits}, None
    )
@@ -529,9 +510,6 @@ def embed_tflite(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
    # Create [Batch, 1, Features]
    embeddings = np.array(embeddings)
    logits = np.array(logits)
-    logits = self.convert_logits(
-        logits, self.class_list, self.target_class_list
-    )
    return interface.InferenceOutputs(
        embeddings, {self.class_list_name: logits}, None
    )
@@ -686,7 +664,6 @@ class PlaceholderModel(interface.EmbeddingModel):
  make_logits: bool = True
  make_separated_audio: bool = True
  do_frame_audio: bool = False
-  target_class_list: namespace.ClassList | None = None
  window_size_s: float = 1.0
  hop_size_s: float = 1.0

@@ -720,9 +697,6 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
            [time_size, len(self.class_list.classes)], np.float32
        ),
    }
-    outputs['logits']['label'] = self.convert_logits(
-        outputs['logits']['label'], self.class_list, self.target_class_list
-    )
    if self.make_separated_audio:
      outputs['separated_audio'] = np.zeros(
          [2, audio_array.shape[-1]], np.float32
5 changes: 1 addition & 4 deletions chirp/tests/inference_test.py
@@ -407,15 +407,12 @@ def test_sep_embed_wrapper(self):
        make_logits=False,
        make_separated_audio=True,
    )
-    db = namespace_db.load_db()
-    target_class_list = db.class_lists['high_sierras']

    embeddor = models.PlaceholderModel(
        sample_rate=22050,
        make_embeddings=True,
        make_logits=True,
        make_separated_audio=False,
-        target_class_list=target_class_list,
    )
    fake_config = config_dict.ConfigDict()
    sep_embed = models.SeparateEmbedModel(
@@ -438,7 +435,7 @@
    )
    # The Sep+Embed model takes the max logits over the channel dimension.
    self.assertSequenceEqual(
-        outputs.logits['label'].shape, [5, len(target_class_list.classes)]
+        outputs.logits['label'].shape, [5, len(embeddor.class_list.classes)]
    )

  def test_pooled_embeddings(self):
