diff --git a/chirp/inference/embed_lib.py b/chirp/inference/embed_lib.py
index 20c7599e..667a2b13 100644
--- a/chirp/inference/embed_lib.py
+++ b/chirp/inference/embed_lib.py
@@ -159,6 +159,7 @@ def __init__(
       write_logits: bool | Sequence[str],
       write_separated_audio: bool,
       write_raw_audio: bool,
+      write_frontend: bool,
       model_key: str,
       model_config: config_dict.ConfigDict,
       file_id_depth: int,
@@ -177,6 +178,8 @@ def __init__(
         logit keys to write.
       write_separated_audio: Whether to write out separated audio tracks.
       write_raw_audio: If true, will add the original audio to the output.
+      write_frontend: If true, will add the model's frontend (spectrogram) to
+        the output.
       model_key: String indicating which model wrapper to use. See MODEL_KEYS.
         Only used for setting up the embedding model.
       model_config: Keyword arg dictionary for the model wrapper class. Only
@@ -201,6 +204,7 @@ def __init__(
     self.write_logits = write_logits
     self.write_separated_audio = write_separated_audio
     self.write_raw_audio = write_raw_audio
+    self.write_frontend = write_frontend
     self.crop_s = crop_s
     self.embedding_model = embedding_model
     self.file_id_depth = file_id_depth
@@ -277,6 +281,7 @@ def audio_to_example(
         write_separated_audio=self.write_separated_audio,
         write_embeddings=self.write_embeddings,
         write_logits=write_logits,
+        write_frontend=self.write_frontend,
         tensor_dtype=self.tensor_dtype,
     )
     return example
diff --git a/chirp/inference/interface.py b/chirp/inference/interface.py
index 577db86e..fa17d9f1 100644
--- a/chirp/inference/interface.py
+++ b/chirp/inference/interface.py
@@ -245,26 +245,24 @@ def embed_from_batch_embed_fn(
     audio_batch = audio_array[np.newaxis, :]
     outputs = embed_fn(audio_batch)

-    if outputs.embeddings is not None:
-      embeddings = outputs.embeddings[0]
-    else:
-      embeddings = None
+    unbatched_outputs = {}
+    for k in ['embeddings', 'separated_audio', 'frontend']:
+      if getattr(outputs, k) is not None:
+        unbatched_outputs[k] = getattr(outputs, k)[0]
+      else:
+        unbatched_outputs[k] = None
+
     if outputs.logits is not None:
       logits = {}
       for k, v in outputs.logits.items():
         logits[k] = v[0]
     else:
       logits = None
-    if outputs.separated_audio is not None:
-      separated_audio = outputs.separated_audio[0]
-    else:
-      separated_audio = None

     return InferenceOutputs(
-        embeddings=embeddings,
         logits=logits,
-        separated_audio=separated_audio,
         batched=False,
+        **unbatched_outputs,
     )
@@ -275,10 +273,13 @@ def batch_embed_from_embed_fn(
     outputs = []
     for audio in audio_batch:
       outputs.append(embed_fn(audio))
-    if outputs[0].embeddings is not None:
-      embeddings = np.stack([x.embeddings for x in outputs], axis=0)
-    else:
-      embeddings = None
+
+    batched_outputs = {}
+    for k in ['embeddings', 'separated_audio', 'frontend']:
+      if getattr(outputs[0], k) is not None:
+        batched_outputs[k] = np.stack([getattr(x, k) for x in outputs], axis=0)
+      else:
+        batched_outputs[k] = None

     if outputs[0].logits is not None:
       batched_logits = {}
@@ -289,16 +290,10 @@
     else:
       batched_logits = None

-    if outputs[0].separated_audio is not None:
-      separated_audio = np.stack([x.separated_audio for x in outputs], axis=0)
-    else:
-      separated_audio = None
-
     return InferenceOutputs(
-        embeddings=embeddings,
         logits=batched_logits,
-        separated_audio=separated_audio,
         batched=True,
+        **batched_outputs,
     )
diff --git a/chirp/inference/models.py b/chirp/inference/models.py
index d366d097..8cd3c852 100644
--- a/chirp/inference/models.py
+++ b/chirp/inference/models.py
@@ -678,9 +678,11 @@ class PlaceholderModel(interface.EmbeddingModel):
   make_embeddings: bool = True
   make_logits: bool = True
   make_separated_audio: bool = True
+  make_frontend: bool = True
   do_frame_audio: bool = False
   window_size_s: float = 1.0
   hop_size_s: float = 1.0
+  frontend_size: tuple[int, int] = (32, 32)

   @classmethod
   def from_config(cls, config: config_dict.ConfigDict) -> 'PlaceholderModel':
@@ -703,6 +705,10 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
       outputs['embeddings'] = np.zeros(
           [time_size, 1, self.embedding_size], np.float32
       )
+    if self.make_frontend:
+      outputs['frontend'] = np.zeros(
+          [time_size, self.frontend_size[0], self.frontend_size[1]], np.float32
+      )
     if self.make_logits:
       outputs['logits'] = {
           'label': np.zeros(
diff --git a/chirp/inference/tf_examples.py b/chirp/inference/tf_examples.py
index 91069387..a517efb4 100644
--- a/chirp/inference/tf_examples.py
+++ b/chirp/inference/tf_examples.py
@@ -36,6 +36,8 @@
 SEPARATED_AUDIO_SHAPE = 'separated_audio_shape'
 RAW_AUDIO = 'raw_audio'
 RAW_AUDIO_SHAPE = 'raw_audio_shape'
+FRONTEND = 'frontend'
+FRONTEND_SHAPE = 'frontend_shape'


 def get_feature_description(logit_names: Sequence[str] | None = None):
@@ -62,6 +64,10 @@
       SEPARATED_AUDIO_SHAPE: tf.io.FixedLenSequenceFeature(
           [], tf.int64, allow_missing=True
       ),
+      FRONTEND: tf.io.FixedLenFeature([], tf.string, default_value=''),
+      FRONTEND_SHAPE: tf.io.FixedLenSequenceFeature(
+          [], tf.int64, allow_missing=True
+      ),
       RAW_AUDIO: tf.io.FixedLenFeature([], tf.string, default_value=''),
       RAW_AUDIO_SHAPE: tf.io.FixedLenSequenceFeature(
           [], tf.int64, allow_missing=True
       ),
@@ -86,7 +92,7 @@

   def _parser(ex):
     ex = tf.io.parse_single_example(ex, features)
-    tensor_keys = [EMBEDDING, SEPARATED_AUDIO, RAW_AUDIO]
+    tensor_keys = [EMBEDDING, SEPARATED_AUDIO, RAW_AUDIO, FRONTEND]
     if logit_names is not None:
       tensor_keys.extend(logit_names)
     for key in tensor_keys:
@@ -140,6 +146,7 @@ def model_outputs_to_tf_example(
     write_logits: bool | Sequence[str],
     write_separated_audio: bool,
     write_raw_audio: bool,
+    write_frontend: bool,
     tensor_dtype: str = 'float32',
 ) -> tf.train.Example:
   """Create a TFExample from InferenceOutputs."""
@@ -153,6 +160,26 @@
     )
     feature[EMBEDDING_SHAPE] = (int_feature(model_outputs.embeddings.shape),)

+  if write_separated_audio and model_outputs.separated_audio is not None:
+    feature[SEPARATED_AUDIO] = bytes_feature(
+        serialize_tensor(model_outputs.separated_audio, tensor_dtype)
+    )
+    feature[SEPARATED_AUDIO_SHAPE] = int_feature(
+        model_outputs.separated_audio.shape
+    )
+
+  if write_frontend and model_outputs.frontend is not None:
+    feature[FRONTEND] = bytes_feature(
+        serialize_tensor(model_outputs.frontend, tensor_dtype)
+    )
+    feature[FRONTEND_SHAPE] = int_feature(model_outputs.frontend.shape)
+
+  if write_raw_audio:
+    feature[RAW_AUDIO] = bytes_feature(
+        serialize_tensor(tf.constant(audio, dtype=tf.float32), tensor_dtype)
+    )
+    feature[RAW_AUDIO_SHAPE] = int_feature(audio.shape)
+  # Handle writing logits.
   if model_outputs.logits is not None and write_logits:
     logit_keys = tuple(model_outputs.logits.keys())
@@ -166,18 +193,6 @@
       )
       feature[logits_key + '_shape'] = int_feature(logits.shape)

-  if write_separated_audio and model_outputs.separated_audio is not None:
-    feature[SEPARATED_AUDIO] = bytes_feature(
-        serialize_tensor(model_outputs.separated_audio, tensor_dtype)
-    )
-    feature[SEPARATED_AUDIO_SHAPE] = int_feature(
-        model_outputs.separated_audio.shape
-    )
-  if write_raw_audio:
-    feature[RAW_AUDIO] = bytes_feature(
-        serialize_tensor(tf.constant(audio, dtype=tf.float32), tensor_dtype)
-    )
-    feature[RAW_AUDIO_SHAPE] = int_feature(audio.shape)

   ex = tf.train.Example(features=tf.train.Features(feature=feature))
   return ex
diff --git a/chirp/tests/inference_test.py b/chirp/tests/inference_test.py
index 3036413b..e559daec 100644
--- a/chirp/tests/inference_test.py
+++ b/chirp/tests/inference_test.py
@@ -49,26 +49,49 @@ def _make_output_head_model(model_path: str, embedding_dim: int = 1280):

 class InferenceTest(parameterized.TestCase):

-  @parameterized.product(
-      make_embeddings=(True, False),
-      make_logits=(True, False),
-      make_separated_audio=(True, False),
-      write_embeddings=(True, False),
-      write_logits=(True, False),
-      write_separated_audio=(True, False),
-      write_raw_audio=(True, False),
-      tensor_dtype=('float32', 'float16'),
+  @parameterized.parameters(
+      # Test each output type individually.
+      {'make_embeddings': True},
+      {'make_embeddings': True, 'write_embeddings': True},
+      {'make_logits': True},
+      {'make_logits': True, 'write_logits': True},
+      {'make_separated_audio': True},
+      {'make_separated_audio': True, 'write_separated_audio': True},
+      {'make_frontend': True},
+      {'make_frontend': True, 'write_frontend': True},
+      {'write_raw_audio': True},
+      # Check float16 handling.
+      {'make_embeddings': True, 'tensor_dtype': 'float16'},
+      {
+          'make_embeddings': True,
+          'write_embeddings': True,
+          'tensor_dtype': 'float16',
+      },
+      # Check with all active.
+      {
+          'make_embeddings': True,
+          'make_logits': True,
+          'make_separated_audio': True,
+          'make_frontend': True,
+          'write_embeddings': True,
+          'write_logits': True,
+          'write_separated_audio': True,
+          'write_frontend': True,
+          'write_raw_audio': True,
+      },
   )
   def test_embed_fn(
       self,
-      make_embeddings,
-      make_logits,
-      make_separated_audio,
-      write_embeddings,
-      write_logits,
-      write_raw_audio,
-      write_separated_audio,
-      tensor_dtype,
+      make_embeddings=False,
+      make_logits=False,
+      make_separated_audio=False,
+      make_frontend=False,
+      write_embeddings=False,
+      write_logits=False,
+      write_raw_audio=False,
+      write_separated_audio=False,
+      write_frontend=False,
+      tensor_dtype='float32',
   ):
     model_kwargs = {
         'sample_rate': 16000,
@@ -76,12 +99,14 @@
         'make_embeddings': make_embeddings,
         'make_logits': make_logits,
         'make_separated_audio': make_separated_audio,
+        'make_frontend': make_frontend,
     }
     embed_fn = embed_lib.EmbedFn(
         write_embeddings=write_embeddings,
         write_logits=write_logits,
         write_separated_audio=write_separated_audio,
         write_raw_audio=write_raw_audio,
+        write_frontend=write_frontend,
         model_key='placeholder_model',
         model_config=model_kwargs,
         file_id_depth=0,
@@ -133,6 +158,14 @@
     else:
       self.assertEqual(got_example[tf_examples.SEPARATED_AUDIO].shape, (0,))

+    if make_frontend and write_frontend:
+      frontend = got_example[tf_examples.FRONTEND]
+      self.assertSequenceEqual(
+          frontend.shape, got_example[tf_examples.FRONTEND_SHAPE]
+      )
+    else:
+      self.assertEqual(got_example[tf_examples.FRONTEND].shape, (0,))
+
     if write_raw_audio:
       raw_audio = got_example[tf_examples.RAW_AUDIO]
       self.assertSequenceEqual(
@@ -155,6 +188,7 @@ def test_embed_fn_source_variations(self):
         write_logits=False,
         write_separated_audio=False,
         write_raw_audio=False,
+        write_frontend=False,
         model_key='placeholder_model',
         model_config=model_kwargs,
         min_audio_s=2.0,
@@ -218,6 +252,7 @@ def test_keyed_write_logits(self):
         write_logits=write_logits,
         write_separated_audio=False,
         write_raw_audio=False,
+        write_frontend=False,
         model_key='placeholder_model',
         model_config=model_kwargs,
         file_id_depth=0,
@@ -299,6 +334,7 @@ def test_embed_short_audio(self):
         write_logits=False,
         write_separated_audio=False,
         write_raw_audio=False,
+        write_frontend=False,
         model_key='placeholder_model',
         model_config=model_kwargs,
         min_audio_s=1.0,
@@ -341,6 +377,7 @@ def test_frame_audio(self):
         write_logits=False,
         write_separated_audio=False,
         write_raw_audio=False,
+        write_frontend=False,
         model_key='placeholder_model',
         model_config=model_kwargs,
         min_audio_s=1.0,
@@ -396,6 +433,7 @@ def test_tfrecord_multiwriter(self):
             write_logits=False,
             write_separated_audio=False,
             write_raw_audio=False,
+            write_frontend=False,
         )
     )
     with tf_examples.EmbeddingsTFRecordMultiWriter(
@@ -434,6 +472,7 @@ def test_get_existing_source_ids(self):
             write_logits=False,
             write_separated_audio=False,
             write_raw_audio=False,
+            write_frontend=False,
         )
     )
     with tf_examples.EmbeddingsTFRecordMultiWriter(
@@ -602,6 +641,7 @@ def test_beam_pipeline(self):
         write_logits=False,
         write_separated_audio=False,
         write_raw_audio=False,
+        write_frontend=False,
         model_key='placeholder_model',
         model_config=model_kwargs,
         file_id_depth=0,
diff --git a/chirp/train/train_utils.py b/chirp/train/train_utils.py
index 33dc3ccc..fe3073ee 100644
--- a/chirp/train/train_utils.py
+++ b/chirp/train/train_utils.py
@@ -37,6 +37,7 @@
 TAXONOMY_KEYS = ['genus', 'family', 'order']

+
 # Note: Inherit from PyTreeNode instead of using the flax.struct.dataclass
 # to avoid PyType issues.
 # See: https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html
@@ -87,6 +88,7 @@ class ModelBundle(flax.struct.PyTreeNode):
   class_lists: dict[str, namespace.ClassList] | None = None
   output_head_metadatas: Sequence[OutputHeadMetadata] | None = None

+
 @flax.struct.dataclass
 class MultiAverage(clu_metrics.Average):
   """Computes the average of all values on the last dimension."""
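
Below the patch, a quick round-trip sketch of the new write_frontend path, using only APIs visible in this diff. Two assumptions are not shown here: InferenceOutputs has gained a `frontend` field (implied by the getattr calls in interface.py, though that hunk is absent), and model_outputs_to_tf_example's unchanged leading arguments are `model_outputs`, `file_id`, `audio`, and `timestamp_offset_s`.

import numpy as np

from chirp.inference import models, tf_examples

# PlaceholderModel now fakes a frontend per frame, shaped by frontend_size.
model = models.PlaceholderModel(
    sample_rate=16000,
    embedding_size=128,
    make_embeddings=True,
    make_logits=False,
    make_separated_audio=False,
    make_frontend=True,
)
audio = np.zeros([5 * model.sample_rate], dtype=np.float32)
outputs = model.embed(audio)

ex = tf_examples.model_outputs_to_tf_example(
    model_outputs=outputs,
    file_id='some_file',      # Assumed bookkeeping args; not in this diff.
    audio=audio,
    timestamp_offset_s=0.0,
    write_embeddings=True,
    write_logits=False,
    write_separated_audio=False,
    write_raw_audio=False,
    write_frontend=True,      # New: serializes outputs.frontend and its shape.
)

# The parser deserializes FRONTEND alongside the other tensor keys; when the
# feature is absent, the '' default parses to an empty (0,) tensor, which is
# exactly what test_embed_fn asserts above.
parser = tf_examples.get_example_parser()
parsed = parser(ex.SerializeToString())
print(parsed[tf_examples.FRONTEND].shape)  # e.g. (time, 32, 32)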
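
A second minimal check of the interface.py refactor: embeddings, separated_audio, and the new frontend now gain or lose their batch dimension uniformly. This sketch assumes PlaceholderModel.batch_embed routes through batch_embed_from_embed_fn, which this diff does not show.

import numpy as np

from chirp.inference import models

model = models.PlaceholderModel(
    sample_rate=16000,
    embedding_size=128,
    make_logits=False,
    make_separated_audio=False,
    make_frontend=True,
)
audio_batch = np.zeros([2, 5 * model.sample_rate], dtype=np.float32)

# batch_embed_from_embed_fn np.stacks every non-None field along axis 0, so
# the frontend picks up the same leading batch dimension as the embeddings.
batched = model.batch_embed(audio_batch)
assert batched.batched
assert batched.embeddings.shape[0] == 2
assert batched.frontend.shape[0] == 2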