
Handle frontend outputs in InferenceOutputs.
PiperOrigin-RevId: 610767302
sdenton4 authored and copybara-github committed Feb 27, 2024
1 parent 8cc4468 commit 7ec6093
Showing 6 changed files with 114 additions and 51 deletions.
5 changes: 5 additions & 0 deletions chirp/inference/embed_lib.py
@@ -159,6 +159,7 @@ def __init__(
write_logits: bool | Sequence[str],
write_separated_audio: bool,
write_raw_audio: bool,
write_frontend: bool,
model_key: str,
model_config: config_dict.ConfigDict,
file_id_depth: int,
@@ -177,6 +178,8 @@ def __init__(
logit keys to write.
write_separated_audio: Whether to write out separated audio tracks.
write_raw_audio: If true, will add the original audio to the output.
write_frontend: If true, will add the model's frontend (spectrogram) to
the output.
model_key: String indicating which model wrapper to use. See MODEL_KEYS.
Only used for setting up the embedding model.
model_config: Keyword arg dictionary for the model wrapper class. Only
@@ -201,6 +204,7 @@ def __init__(
self.write_logits = write_logits
self.write_separated_audio = write_separated_audio
self.write_raw_audio = write_raw_audio
self.write_frontend = write_frontend
self.crop_s = crop_s
self.embedding_model = embedding_model
self.file_id_depth = file_id_depth
@@ -277,6 +281,7 @@ def audio_to_example(
write_separated_audio=self.write_separated_audio,
write_embeddings=self.write_embeddings,
write_logits=write_logits,
write_frontend=self.write_frontend,
tensor_dtype=self.tensor_dtype,
)
return example
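For context, a minimal sketch of the new flag in use, modeled on the setup in `chirp/tests/inference_test.py` below; the placeholder-model config values are illustrative:

```python
from chirp.inference import embed_lib

# Illustrative placeholder-model config, mirroring the repository's tests.
model_kwargs = {
    'sample_rate': 16000,
    'embedding_size': 128,
    'make_embeddings': True,
    'make_frontend': True,  # Model produces a spectrogram frontend.
}
embed_fn = embed_lib.EmbedFn(
    write_embeddings=True,
    write_logits=False,
    write_separated_audio=False,
    write_raw_audio=False,
    write_frontend=True,  # New flag: persist the frontend in the output.
    model_key='placeholder_model',
    model_config=model_kwargs,
    file_id_depth=0,
)
```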
37 changes: 16 additions & 21 deletions chirp/inference/interface.py
@@ -245,26 +245,24 @@ def embed_from_batch_embed_fn(
audio_batch = audio_array[np.newaxis, :]
outputs = embed_fn(audio_batch)

if outputs.embeddings is not None:
embeddings = outputs.embeddings[0]
else:
embeddings = None
unbatched_outputs = {}
for k in ['embeddings', 'separated_audio', 'frontend']:
if getattr(outputs, k) is not None:
unbatched_outputs[k] = getattr(outputs, k)[0]
else:
unbatched_outputs[k] = None

if outputs.logits is not None:
logits = {}
for k, v in outputs.logits.items():
logits[k] = v[0]
else:
logits = None
if outputs.separated_audio is not None:
separated_audio = outputs.separated_audio[0]
else:
separated_audio = None

return InferenceOutputs(
embeddings=embeddings,
logits=logits,
separated_audio=separated_audio,
batched=False,
**unbatched_outputs,
)


@@ -275,10 +273,13 @@ def batch_embed_from_embed_fn(
outputs = []
for audio in audio_batch:
outputs.append(embed_fn(audio))
if outputs[0].embeddings is not None:
embeddings = np.stack([x.embeddings for x in outputs], axis=0)
else:
embeddings = None

batched_outputs = {}
for k in ['embeddings', 'separated_audio', 'frontend']:
if getattr(outputs[0], k) is not None:
batched_outputs[k] = np.stack([getattr(x, k) for x in outputs], axis=0)
else:
batched_outputs[k] = None

if outputs[0].logits is not None:
batched_logits = {}
@@ -289,16 +290,10 @@
else:
batched_logits = None

if outputs[0].separated_audio is not None:
separated_audio = np.stack([x.separated_audio for x in outputs], axis=0)
else:
separated_audio = None

return InferenceOutputs(
embeddings=embeddings,
logits=batched_logits,
separated_audio=separated_audio,
batched=True,
**batched_outputs,
)


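The changes above collapse the per-field `if`/`else` blocks into loops over the optional output fields. A distilled sketch of that pattern, using a stand-in dataclass rather than the real `InferenceOutputs`:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class Outputs:
  """Stand-in for InferenceOutputs; each field is optional."""

  embeddings: np.ndarray | None = None
  separated_audio: np.ndarray | None = None
  frontend: np.ndarray | None = None


FIELDS = ('embeddings', 'separated_audio', 'frontend')


def unbatch(outputs: Outputs) -> dict[str, np.ndarray | None]:
  # Take element 0 of each present field; absent fields stay None.
  return {
      k: getattr(outputs, k)[0] if getattr(outputs, k) is not None else None
      for k in FIELDS
  }


def stack(outputs_list: list[Outputs]) -> dict[str, np.ndarray | None]:
  # Stack each field across the batch when the first output provides it.
  return {
      k: (
          np.stack([getattr(x, k) for x in outputs_list], axis=0)
          if getattr(outputs_list[0], k) is not None
          else None
      )
      for k in FIELDS
  }
```

The resulting dict can then be splatted into the constructor, e.g. `InferenceOutputs(batched=False, **unbatched_outputs)`, which is what the diff does.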
6 changes: 6 additions & 0 deletions chirp/inference/models.py
@@ -678,9 +678,11 @@ class PlaceholderModel(interface.EmbeddingModel):
make_embeddings: bool = True
make_logits: bool = True
make_separated_audio: bool = True
make_frontend: bool = True
do_frame_audio: bool = False
window_size_s: float = 1.0
hop_size_s: float = 1.0
frontend_size: tuple[int, int] = (32, 32)

@classmethod
def from_config(cls, config: config_dict.ConfigDict) -> 'PlaceholderModel':
@@ -703,6 +705,10 @@ def embed(self, audio_array: np.ndarray) -> interface.InferenceOutputs:
outputs['embeddings'] = np.zeros(
[time_size, 1, self.embedding_size], np.float32
)
if self.make_frontend:
outputs['frontend'] = np.zeros(
[time_size, self.frontend_size[0], self.frontend_size[1]], np.float32
)
if self.make_logits:
outputs['logits'] = {
'label': np.zeros(
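A hedged sketch of exercising the new frontend output; it assumes `PlaceholderModel` can be constructed directly with the dataclass fields shown above and that `InferenceOutputs` now exposes a `frontend` attribute, both consistent with this commit:

```python
import numpy as np

from chirp.inference import models

model = models.PlaceholderModel(
    sample_rate=16000,
    make_embeddings=True,
    make_frontend=True,
    frontend_size=(32, 32),
)
# Five seconds of silence; with the default 1.0 s hop, time_size should be 5.
outputs = model.embed(np.zeros(5 * 16000, dtype=np.float32))
print(outputs.frontend.shape)  # Expected: (5, 32, 32), all zeros.
```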
41 changes: 28 additions & 13 deletions chirp/inference/tf_examples.py
@@ -36,6 +36,8 @@
SEPARATED_AUDIO_SHAPE = 'separated_audio_shape'
RAW_AUDIO = 'raw_audio'
RAW_AUDIO_SHAPE = 'raw_audio_shape'
FRONTEND = 'frontend'
FRONTEND_SHAPE = 'frontend_shape'


def get_feature_description(logit_names: Sequence[str] | None = None):
@@ -62,6 +64,10 @@ def get_feature_description(logit_names: Sequence[str] | None = None):
SEPARATED_AUDIO_SHAPE: tf.io.FixedLenSequenceFeature(
[], tf.int64, allow_missing=True
),
FRONTEND: tf.io.FixedLenFeature([], tf.string, default_value=''),
FRONTEND_SHAPE: tf.io.FixedLenSequenceFeature(
[], tf.int64, allow_missing=True
),
RAW_AUDIO: tf.io.FixedLenFeature([], tf.string, default_value=''),
RAW_AUDIO_SHAPE: tf.io.FixedLenSequenceFeature(
[], tf.int64, allow_missing=True
@@ -86,7 +92,7 @@ def get_example_parser(

def _parser(ex):
ex = tf.io.parse_single_example(ex, features)
tensor_keys = [EMBEDDING, SEPARATED_AUDIO, RAW_AUDIO]
tensor_keys = [EMBEDDING, SEPARATED_AUDIO, RAW_AUDIO, FRONTEND]
if logit_names is not None:
tensor_keys.extend(logit_names)
for key in tensor_keys:
@@ -140,6 +146,7 @@ def model_outputs_to_tf_example(
write_logits: bool | Sequence[str],
write_separated_audio: bool,
write_raw_audio: bool,
write_frontend: bool,
tensor_dtype: str = 'float32',
) -> tf.train.Example:
"""Create a TFExample from InferenceOutputs."""
@@ -153,6 +160,26 @@
)
feature[EMBEDDING_SHAPE] = (int_feature(model_outputs.embeddings.shape),)

if write_separated_audio and model_outputs.separated_audio is not None:
feature[SEPARATED_AUDIO] = bytes_feature(
serialize_tensor(model_outputs.separated_audio, tensor_dtype)
)
feature[SEPARATED_AUDIO_SHAPE] = int_feature(
model_outputs.separated_audio.shape
)

if write_frontend and model_outputs.frontend is not None:
feature[FRONTEND] = bytes_feature(
serialize_tensor(model_outputs.frontend, tensor_dtype)
)
feature[FRONTEND_SHAPE] = int_feature(model_outputs.frontend.shape)

if write_raw_audio:
feature[RAW_AUDIO] = bytes_feature(
serialize_tensor(tf.constant(audio, dtype=tf.float32), tensor_dtype)
)
feature[RAW_AUDIO_SHAPE] = int_feature(audio.shape)

# Handle writing logits.
if model_outputs.logits is not None and write_logits:
logit_keys = tuple(model_outputs.logits.keys())
@@ -166,18 +193,6 @@
)
feature[logits_key + '_shape'] = int_feature(logits.shape)

if write_separated_audio and model_outputs.separated_audio is not None:
feature[SEPARATED_AUDIO] = bytes_feature(
serialize_tensor(model_outputs.separated_audio, tensor_dtype)
)
feature[SEPARATED_AUDIO_SHAPE] = int_feature(
model_outputs.separated_audio.shape
)
if write_raw_audio:
feature[RAW_AUDIO] = bytes_feature(
serialize_tensor(tf.constant(audio, dtype=tf.float32), tensor_dtype)
)
feature[RAW_AUDIO_SHAPE] = int_feature(audio.shape)
ex = tf.train.Example(features=tf.train.Features(feature=feature))
return ex

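A sketch of reading the frontend back from a written dataset. It assumes `get_example_parser` can be called with its defaults (consistent with the default `logit_names=None` in `get_feature_description` above); the TFRecord path is hypothetical:

```python
import tensorflow as tf

from chirp.inference import tf_examples

ds = tf.data.TFRecordDataset(['embeddings.tfrecord'])  # Hypothetical path.
ds = ds.map(tf_examples.get_example_parser())
for ex in ds.take(1):
  # FRONTEND deserializes to the spectrogram when it was written; examples
  # written without it parse to an empty tensor (the '' default above).
  print(ex[tf_examples.FRONTEND].shape, ex[tf_examples.FRONTEND_SHAPE])
```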
74 changes: 57 additions & 17 deletions chirp/tests/inference_test.py
@@ -49,39 +49,64 @@ def _make_output_head_model(model_path: str, embedding_dim: int = 1280):

class InferenceTest(parameterized.TestCase):

@parameterized.product(
make_embeddings=(True, False),
make_logits=(True, False),
make_separated_audio=(True, False),
write_embeddings=(True, False),
write_logits=(True, False),
write_separated_audio=(True, False),
write_raw_audio=(True, False),
tensor_dtype=('float32', 'float16'),
@parameterized.parameters(
# Test each output type individually.
{'make_embeddings': True},
{'make_embeddings': True, 'write_embeddings': True},
{'make_logits': True},
{'make_logits': True, 'write_logits': True},
{'make_separated_audio': True},
{'make_separated_audio': True, 'write_separated_audio': True},
{'make_frontend': True},
{'make_frontend': True, 'write_frontend': True},
{'write_raw_audio': True},
# Check float16 handling.
{'make_embeddings': True, 'tensor_dtype': 'float16'},
{
'make_embeddings': True,
'write_embeddings': True,
'tensor_dtype': 'float16',
},
# Check with all active.
{
'make_embeddings': True,
'make_logits': True,
'make_separated_audio': True,
'make_frontend': True,
'write_embeddings': True,
'write_logits': True,
'write_separated_audio': True,
'write_frontend': True,
'write_raw_audio': True,
},
)
def test_embed_fn(
self,
make_embeddings,
make_logits,
make_separated_audio,
write_embeddings,
write_logits,
write_raw_audio,
write_separated_audio,
tensor_dtype,
make_embeddings=False,
make_logits=False,
make_separated_audio=False,
make_frontend=False,
write_embeddings=False,
write_logits=False,
write_raw_audio=False,
write_separated_audio=False,
write_frontend=False,
tensor_dtype='float32',
):
model_kwargs = {
'sample_rate': 16000,
'embedding_size': 128,
'make_embeddings': make_embeddings,
'make_logits': make_logits,
'make_separated_audio': make_separated_audio,
'make_frontend': make_frontend,
}
embed_fn = embed_lib.EmbedFn(
write_embeddings=write_embeddings,
write_logits=write_logits,
write_separated_audio=write_separated_audio,
write_raw_audio=write_raw_audio,
write_frontend=write_frontend,
model_key='placeholder_model',
model_config=model_kwargs,
file_id_depth=0,
@@ -133,6 +158,14 @@ def test_embed_fn(
else:
self.assertEqual(got_example[tf_examples.SEPARATED_AUDIO].shape, (0,))

if make_frontend and write_frontend:
frontend = got_example[tf_examples.FRONTEND]
self.assertSequenceEqual(
frontend.shape, got_example[tf_examples.FRONTEND_SHAPE]
)
else:
self.assertEqual(got_example[tf_examples.FRONTEND].shape, (0,))

if write_raw_audio:
raw_audio = got_example[tf_examples.RAW_AUDIO]
self.assertSequenceEqual(
@@ -155,6 +188,7 @@ def test_embed_fn_source_variations(self):
write_logits=False,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
model_key='placeholder_model',
model_config=model_kwargs,
min_audio_s=2.0,
@@ -218,6 +252,7 @@ def test_keyed_write_logits(self):
write_logits=write_logits,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
model_key='placeholder_model',
model_config=model_kwargs,
file_id_depth=0,
@@ -299,6 +334,7 @@ def test_embed_short_audio(self):
write_logits=False,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
model_key='placeholder_model',
model_config=model_kwargs,
min_audio_s=1.0,
@@ -341,6 +377,7 @@ def test_frame_audio(self):
write_logits=False,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
model_key='placeholder_model',
model_config=model_kwargs,
min_audio_s=1.0,
@@ -396,6 +433,7 @@ def test_tfrecord_multiwriter(self):
write_logits=False,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
)
)
with tf_examples.EmbeddingsTFRecordMultiWriter(
@@ -434,6 +472,7 @@ def test_get_existing_source_ids(self):
write_logits=False,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
)
)
with tf_examples.EmbeddingsTFRecordMultiWriter(
@@ -602,6 +641,7 @@ def test_beam_pipeline(self):
write_logits=False,
write_separated_audio=False,
write_raw_audio=False,
write_frontend=False,
model_key='placeholder_model',
model_config=model_kwargs,
file_id_depth=0,
2 changes: 2 additions & 0 deletions chirp/train/train_utils.py
@@ -37,6 +37,7 @@

TAXONOMY_KEYS = ['genus', 'family', 'order']


# Note: Inherit from PyTreeNode instead of using the flax.struct.dataclass
# to avoid PyType issues.
# See: https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html
@@ -87,6 +88,7 @@ class ModelBundle(flax.struct.PyTreeNode):
class_lists: dict[str, namespace.ClassList] | None = None
output_head_metadatas: Sequence[OutputHeadMetadata] | None = None


@flax.struct.dataclass
class MultiAverage(clu_metrics.Average):
"""Computes the average of all values on the last dimension."""
