Support longer audio contexts #110

Open · wants to merge 9 commits into base: main
5 changes: 3 additions & 2 deletions ultravox/data/datasets.py
@@ -79,6 +79,7 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):

def __call__(self, features, *args, **kwargs):
audio_values = [f.pop("audio_values", None) for f in features]

if self.include_alt_fields:
# these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
alt_features = [
@@ -89,6 +90,7 @@ def __call__(self, features, *args, **kwargs):
}
for f in features
]

input_ids_lens = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
batch = super().__call__(features, *args, **kwargs)
if self.include_alt_fields:
@@ -100,15 +102,14 @@ def __call__(self, features, *args, **kwargs):
# Pad the last dimension of all audio_values to the same length, with 0s on the right.
if audio_values and audio_values[0] is not None:
max_len = max([x.shape[-1] for x in audio_values])
batch["audio_values"] = torch.stack(
batch["audio_values"] = torch.cat(
[F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
)
if self.tokenizer.padding_side == "left":
displacement = batch["input_ids"].shape[-1] - input_ids_lens
batch["audio_token_start_idx"] += displacement.to(
batch["audio_token_start_idx"].device
)

return batch


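Note on the collator change above: torch.stack requires every element to have the same shape and adds a new batch dimension, while torch.cat merges along an existing one. Since each sample's "audio_values" now arrives as a (num_chunks, D, T) tensor and samples may contribute different numbers of chunks, concatenating along dim 0 yields one batch of chunks. A minimal runnable sketch of the padding step, with illustrative shapes:

import torch
import torch.nn.functional as F

# Each sample contributes a (num_chunks, 80, T_i) mel tensor; the collator
# right-pads T_i to the batch max, then concatenates chunks along dim 0.
audio_values = [
    torch.randn(2, 80, 3000),  # 60s sample -> two 30s Whisper chunks
    torch.randn(1, 80, 1500),  # 15s sample -> one shorter chunk
]
max_len = max(x.shape[-1] for x in audio_values)
batched = torch.cat([F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values])
print(batched.shape)  # torch.Size([3, 80, 3000]): chunks, not samples, on dim 0
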
3 changes: 2 additions & 1 deletion ultravox/inference/infer.py
@@ -111,7 +111,8 @@ def infer_batch(
inputs = [self._dataproc(s) for s in samples]
for input in inputs:
for key, val in input.items():
-input[key] = val.squeeze(0)
+if key != "audio_values":
+input[key] = val.squeeze(0)

tensors = self.data_collator(inputs)
input_len = tensors["input_ids"].shape[1]
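Note on the squeeze guard above: the other fields carry a dummy batch dimension that the data collator re-adds, but "audio_values" is already a batch of chunks, so squeezing would collapse a single-chunk (1, 80, T) tensor to (80, T) and break collation. A small sketch, with illustrative shapes:

import torch

inputs = {
    "input_ids": torch.ones(1, 20, dtype=torch.long),  # (1, seq_len)
    "audio_values": torch.randn(1, 80, 3000),          # (num_chunks, 80, T)
}
for key, val in inputs.items():
    if key != "audio_values":
        inputs[key] = val.squeeze(0)  # drop the dummy batch dim on text fields

assert inputs["input_ids"].shape == (20,)
assert inputs["audio_values"].shape == (1, 80, 3000)  # chunk dim preserved
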
36 changes: 32 additions & 4 deletions ultravox/inference/infer_test.py
@@ -1,3 +1,4 @@
from typing import Optional
from unittest import mock

import numpy as np
@@ -30,16 +31,23 @@ def audio_processor():
)


@pytest.fixture(scope="module")
def audio_processor_whisper():
return transformers.AutoProcessor.from_pretrained("openai/whisper-tiny")


class FakeInference(infer.LocalInference):
def __init__(
self,
tokenizer: transformers.PreTrainedTokenizer,
audio_processor: transformers.ProcessorMixin,
audio_context_size: Optional[int] = None,
):
def fake_generate(**kwargs):
input = kwargs.get("input_ids")
input_len = input.shape[1] if input is not None else 0
output = transformers.generation.utils.GenerateDecoderOnlyOutput(
-sequences=[range(25)]
+sequences=[range(input_len + 5)]  # Always output 5 tokens
)
streamer = kwargs.get("streamer", None)
if streamer:
@@ -49,7 +57,7 @@ def fake_generate(**kwargs):
return output

processor = ultravox_processing.UltravoxProcessor(
-audio_processor, tokenizer=tokenizer
+audio_processor, tokenizer=tokenizer, audio_context_size=audio_context_size
)
super().__init__(
mock.MagicMock(),
@@ -66,6 +74,26 @@ def fake_generate(**kwargs):
EXPECTED_TOKEN_IDS_END = [128009, 128006, 78191, 128007, 271]


def test_long_audio_context(tokenizer, audio_processor_whisper):
"""Ensure we handle long audio context properly."""
inference = FakeInference(
tokenizer, audio_processor_whisper, audio_context_size=3000
)
array = np.ones(960000, dtype=np.float32)
sample = datasets.VoiceSample.from_prompt_and_raw(
"Transcribe\n<|audio|>", array, 16000
)
output = inference.infer(sample)
assert output.input_tokens == 388
assert output.output_tokens == 5
assert output.text == "ers on conapub"
generate_args = inference.model.generate.call_args[1]
assert generate_args["audio_values"].shape == (2, 80, 3000)
assert generate_args["audio_token_len"].item() == torch.tensor(375)
assert generate_args["audio_token_start_idx"] == torch.tensor(8)
assert generate_args["audio_batch_size"] == torch.tensor(2)


def test_infer_16kHz(tokenizer, audio_processor):
"""Ensure we handle 16kHz float32 audio properly."""
inference = FakeInference(tokenizer, audio_processor)
@@ -148,8 +176,8 @@ def test_infer_text_only(tokenizer, audio_processor):
sample = datasets.VoiceSample.from_prompt("Hello?")
output = inference.infer(sample)
assert output.input_tokens == 12
-assert output.output_tokens == 13
-assert output.text == "-./0123456789"
+assert output.output_tokens == 5
+assert output.text == "-./01"
generate_args = inference.model.generate.call_args[1]
assert generate_args.get("audio_values") is None
call_input_ids = generate_args["input_ids"]
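The expected values in test_long_audio_context follow from the processor arithmetic rather than magic numbers: 960,000 samples at 16 kHz is 60 s of audio; Whisper's feature extractor emits 100 mel frames per second, so 6,000 frames, which split at audio_context_size=3000 gives the asserted (2, 80, 3000) shape; and with the processor defaults encoder_ds_factor=320 and stack_factor=8, the audio token count is 960000 / 320 / 8 = 375. A worked check under those assumptions:

# Worked check of the shapes asserted in test_long_audio_context.
sampling_rate = 16_000
num_samples = 960_000                              # -> 60 s of audio
mel_frames = num_samples // 160                    # 6000 frames at 100 frames/s
audio_context_size = 3000                          # Whisper encoder context
num_chunks = -(-mel_frames // audio_context_size)  # ceil division -> 2
audio_token_len = num_samples // (320 * 8)         # 375 projector tokens
assert (num_chunks, audio_token_len) == (2, 375)
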
5 changes: 4 additions & 1 deletion ultravox/inference/ultravox_infer.py
@@ -58,7 +58,10 @@ def __init__(
)

processor = ultravox_processing.UltravoxProcessor(
-audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor
+audio_processor,
+tokenizer=tokenizer,
+stack_factor=model.config.stack_factor,
+audio_context_size=model.audio_tower_context_length,
)

super().__init__(
32 changes: 24 additions & 8 deletions ultravox/model/ultravox_model.py
@@ -51,6 +51,10 @@ def __init__(self, config: UltravoxConfig):
self.vocab_size = config.vocab_size

self.audio_tower = self._create_audio_tower(config)
self.audio_tower_context_length: Optional[int] = None
if config.audio_model_id is not None and "whisper" in config.audio_model_id:
self.audio_tower_context_length = 3000

self.multi_modal_projector = self._create_multi_modal_projector(config)
self.language_model = self._create_language_model(config)

@@ -155,6 +159,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
audio_batch_size: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
# the alt_* fields are needed for KL divergence loss
alt_input_ids: Optional[torch.Tensor] = None,
@@ -186,27 +191,36 @@ def forward(
inputs_embeds = self.get_input_embeddings().forward(input_ids)

if audio_values is not None:

assert (
-audio_token_start_idx is not None and audio_token_len is not None
-), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
+audio_token_start_idx is not None
+and audio_token_len is not None
+and audio_batch_size is not None
+), "audio_token_start_idx and audio_token_len and audio_batch_size must be provided if audio_values are provided."
assert (
-len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
-), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
+len(audio_token_start_idx)
+== len(audio_token_len)
+== len(audio_batch_size)
+), "audio_token_start_idx and audio_token_len and audio_batch_size must have the same batch size."

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
audio_values.to(self.audio_tower.dtype)
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

audio_embeds = self.multi_modal_projector.forward(audio_tower_output)

# combine audio and text embeddings
-for i, (audio, start, length) in enumerate(
-zip(audio_embeds, audio_token_start_idx, audio_token_len)
+audio_ind = 0
+for i, (start, length, batch_size) in enumerate(
+zip(audio_token_start_idx, audio_token_len, audio_batch_size)
):
+audio = torch.cat(
+[audio_embeds[k] for k in range(audio_ind, audio_ind + batch_size)],
+dim=0,
+)
length = min(length, audio.shape[0])
inputs_embeds[i, start : start + length] = audio[:length]
+audio_ind += batch_size

lm_output = self.language_model.forward(
inputs_embeds=inputs_embeds,
@@ -241,6 +255,7 @@ def prepare_inputs_for_generation(
audio_values: Optional[torch.FloatTensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
audio_batch_size: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -269,6 +284,7 @@ def prepare_inputs_for_generation(
audio_token_start_idx - prefill_start_idx
)
model_input["audio_token_len"] = audio_token_len
model_input["audio_batch_size"] = audio_batch_size

return model_input

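For intuition, the regrouping loop added to forward() can be exercised on its own: audio_embeds holds one row per chunk, and audio_batch_size records how many chunks each text sample contributed, so the loop concatenates that many rows before splicing them into the sample's text embeddings. A standalone sketch with toy dimensions:

import torch

hidden = 8
inputs_embeds = torch.zeros(2, 32, hidden)   # (batch, seq_len, D)
audio_embeds = torch.randn(3, 10, hidden)    # 3 chunks total, 10 tokens each
audio_token_start_idx = torch.tensor([4, 2])
audio_token_len = torch.tensor([20, 10])
audio_batch_size = torch.tensor([2, 1])      # sample 0 owns 2 chunks, sample 1 owns 1

audio_ind = 0
for i, (start, length, batch_size) in enumerate(
    zip(audio_token_start_idx, audio_token_len, audio_batch_size)
):
    # stitch this sample's chunks back together along the token axis
    audio = torch.cat(
        [audio_embeds[k] for k in range(audio_ind, audio_ind + batch_size)], dim=0
    )
    length = min(length, audio.shape[0])
    inputs_embeds[i, start : start + length] = audio[:length]
    audio_ind += batch_size
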
1 change: 1 addition & 0 deletions ultravox/model/ultravox_pipeline.py
@@ -37,6 +37,7 @@ def __init__(
audio_processor=audio_processor,
tokenizer=tokenizer,
stack_factor=model.config.stack_factor,
audio_context_size=model.audio_tower_context_length,
)

super().__init__(model=model, tokenizer=tokenizer, **kwargs)
42 changes: 38 additions & 4 deletions ultravox/model/ultravox_processing.py
@@ -1,7 +1,8 @@
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import transformers

from .ultravox_config import UltravoxConfig
@@ -38,6 +39,9 @@ def __init__(
encoder_ds_factor: int = 320,
stack_factor: int = 8,
audio_placeholder: str = "<|audio|>",
audio_context_size: Optional[
int
] = 3000, # Defaults to whisper encoder context size
):
"""
Args:
@@ -53,6 +57,7 @@ def __init__(
self.stack_factor = stack_factor
self.audio_placeholder = audio_placeholder
self.audio_token_replacement = tokenizer.eos_token
self.audio_context_size = audio_context_size
assert (
self.audio_token_replacement is not None
), "The tokenizer has no EOS token. Cannot recover."
@@ -132,7 +137,7 @@ def __call__(
- **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
"""
# TODO: Add support for multiple audio and text inputs.
-data = {}
+data: Dict[str, Any] = {}
audio_embed_frames = 0
if audio is not None and len(audio) > 0:
if self.audio_padding == "max_length":
@@ -141,6 +146,7 @@ def __call__(
audio_len = 30 * sampling_rate
else:
audio_len = audio.shape[-1]

# It's guaranteed that the number of frames is less than or equal to this amount.
# For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
# Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
@@ -157,9 +163,37 @@ def __call__(
**kwargs,
)
if "input_features" in x:
data["audio_values"] = x.input_features
audio_values = x.input_features
else:
data["audio_values"] = x.input_values
audio_values = x.input_values

audio_values = torch.tensor(audio_values)
if (
self.audio_context_size
and audio_values.shape[-1] > self.audio_context_size
):
audio_values_chunks = list(
torch.split(
audio_values,
self.audio_context_size,
dim=len(audio_values.shape) - 1,
)
)
# Pad the last chunk to match audio_context_size
last_chunk = audio_values_chunks[-1]
pad_size = self.audio_context_size - last_chunk.shape[-1]
if pad_size > 0:
# Pad only the last dimension (T) in B,D,T format
audio_values_chunks[-1] = F.pad(
last_chunk, (0, pad_size, 0, 0, 0, 0)
)
else:
audio_values_chunks = [audio_values]

data["audio_values"] = torch.cat(audio_values_chunks)
num_audio_chunks = data["audio_values"].shape[0]

data["audio_batch_size"] = [num_audio_chunks]

if text is not None:
assert isinstance(
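The chunking added to UltravoxProcessor is plain torch.split along the time axis, with the tail chunk zero-padded up to audio_context_size so all chunks share one shape (the (0, pad_size, 0, 0, 0, 0) padding tuple in the diff touches only the last dimension; the trailing zero pairs are no-ops). A self-contained sketch of that step, assuming (B, 80, T) mel features:

import torch
import torch.nn.functional as F

audio_context_size = 3000
audio_values = torch.randn(1, 80, 7000)  # ~70 s of mel frames, illustrative

chunks = list(torch.split(audio_values, audio_context_size, dim=-1))
pad_size = audio_context_size - chunks[-1].shape[-1]
if pad_size > 0:
    chunks[-1] = F.pad(chunks[-1], (0, pad_size))  # right-pad the time axis

audio_values = torch.cat(chunks)            # chunk dim replaces the batch dim
audio_batch_size = [audio_values.shape[0]]  # consumed later by the model
assert audio_values.shape == (3, 80, 3000)
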
7 changes: 6 additions & 1 deletion ultravox/training/train.py
@@ -117,7 +117,6 @@ def train(args: config_base.TrainConfig):
text_tokenizer.padding_side = "right"
text_tokenizer.pad_token = text_tokenizer.eos_token
audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model)
-processor = ultravox_processing.UltravoxProcessor(audio_processor, text_tokenizer)

# Instantiate the model and processor
config = ultravox_config.UltravoxConfig(
@@ -142,6 +141,12 @@
with model_load_context:
model = ultravox_model.UltravoxModel(config)

+processor = ultravox_processing.UltravoxProcessor(
+audio_processor,
+text_tokenizer,
+audio_context_size=model.audio_tower_context_length,
+)

assert model.get_input_embeddings().num_embeddings == len(
text_tokenizer
), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}"