Support longer audio contexts #110

Open
wants to merge 9 commits into base: main
Changes from 6 commits
5 changes: 3 additions & 2 deletions ultravox/data/datasets.py
@@ -79,6 +79,7 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):

def __call__(self, features, *args, **kwargs):
audio_values = [f.pop("audio_values", None) for f in features]

if self.include_alt_fields:
# these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
alt_features = [
@@ -89,6 +90,7 @@ def __call__(self, features, *args, **kwargs):
}
for f in features
]

input_ids_lens = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
batch = super().__call__(features, *args, **kwargs)
if self.include_alt_fields:
@@ -100,15 +102,14 @@ def __call__(self, features, *args, **kwargs):
# Pad the last dimension of all audio_values to the same length, with 0s on the right.
if audio_values and audio_values[0] is not None:
max_len = max([x.shape[-1] for x in audio_values])
batch["audio_values"] = torch.stack(
batch["audio_values"] = torch.cat(
[F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
)
if self.tokenizer.padding_side == "left":
displacement = batch["input_ids"].shape[-1] - input_ids_lens
batch["audio_token_start_idx"] += displacement.to(
batch["audio_token_start_idx"].device
)

return batch


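Note on the collator change: each sample's audio_values now arrives with its own leading chunk dimension (num_chunks, D, T), so the collator pads the time dimension and concatenates along dim 0 instead of stacking. A minimal sketch with made-up shapes (not code from this PR):

```python
import torch
import torch.nn.functional as F

# Hypothetical per-sample features, each already shaped (num_chunks, n_mels, frames):
audio_values = [
    torch.randn(1, 80, 3000),  # ~30 s sample -> 1 chunk
    torch.randn(2, 80, 2600),  # longer sample -> 2 chunks, shorter time dimension
]

# Pad every sample's time dimension to the batch maximum, then concatenate the
# chunk dimensions into one flat batch of chunks.
max_len = max(x.shape[-1] for x in audio_values)
batch_audio = torch.cat([F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values])
print(batch_audio.shape)  # torch.Size([3, 80, 3000])
```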
4 changes: 3 additions & 1 deletion ultravox/inference/infer.py
@@ -111,7 +111,8 @@ def infer_batch(
inputs = [self._dataproc(s) for s in samples]
for input in inputs:
for key, val in input.items():
- input[key] = val.squeeze(0)
+ if key != "audio_values":
+     input[key] = val.squeeze(0)

tensors = self.data_collator(inputs)
input_len = tensors["input_ids"].shape[1]
@@ -198,6 +199,7 @@ def _dataproc(self, sample: datasets.VoiceSample):
text=text_input,
return_tensors="pt",
sampling_rate=SAMPLE_RATE,
audio_context_size=self.model.audio_tower_context_length,
)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
if "audio_values" in inputs:
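For context on the squeeze guard above: the processor can now return several chunks for a single sample, so the leading dimension of audio_values is a chunk count rather than a dummy batch dimension and must be preserved. A rough illustration with assumed shapes:

```python
import torch

# Assumed processor output for one long sample (shapes invented for illustration):
inputs = {
    "input_ids": torch.zeros(1, 128, dtype=torch.long),   # (1, seq_len)
    "audio_token_start_idx": torch.tensor([5]),
    "audio_values": torch.zeros(2, 80, 3000),              # (num_chunks, n_mels, frames)
}

for key, val in inputs.items():
    if key != "audio_values":
        inputs[key] = val.squeeze(0)  # drop the per-sample batch dim for text fields

# audio_values keeps its chunk dimension so the data collator can torch.cat
# chunks from multiple samples into a single audio batch.
```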
1 change: 1 addition & 0 deletions ultravox/inference/infer_test.py
@@ -60,6 +60,7 @@ def fake_generate(**kwargs):
)
self.model.device = "cpu"
self.model.generate = mock.MagicMock(side_effect=fake_generate)
self.model.audio_tower_context_length = None


EXPECTED_TOKEN_IDS_START = [128000, 128006, 882, 128007]
Review comment (Contributor):
suggest adding a new test to this file to demonstrate the long-audio-context handling in the processor

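A hedged sketch of what such a test could check; the split_audio_into_chunks helper below is a hypothetical stand-in that mirrors the processor's chunking logic, not an existing function in this repo:

```python
import torch
import torch.nn.functional as F


def split_audio_into_chunks(audio_values: torch.Tensor, context_size: int) -> torch.Tensor:
    """Hypothetical helper mirroring the processor's chunking of long audio features."""
    if audio_values.shape[-1] <= context_size:
        return audio_values
    chunks = list(torch.split(audio_values, context_size, dim=-1))
    pad_size = context_size - chunks[-1].shape[-1]
    if pad_size > 0:
        chunks[-1] = F.pad(chunks[-1], (0, pad_size))  # zero-pad the time dimension
    return torch.cat(chunks)


def test_long_audio_is_chunked():
    # 5600 feature frames against a 3000-frame context -> two chunks, second padded.
    features = torch.randn(1, 80, 5600)
    chunked = split_audio_into_chunks(features, 3000)
    assert chunked.shape == (2, 80, 3000)
    assert torch.all(chunked[1, :, 2600:] == 0)
```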
33 changes: 25 additions & 8 deletions ultravox/model/ultravox_model.py
@@ -51,6 +51,10 @@ def __init__(self, config: UltravoxConfig):
self.vocab_size = config.vocab_size

self.audio_tower = self._create_audio_tower(config)
self.audio_tower_context_length: Optional[int] = None
if config.audio_model_id is not None and "whisper" in config.audio_model_id:
self.audio_tower_context_length = 3000

self.multi_modal_projector = UltravoxProjector(config)
self.language_model = self._create_language_model(config)

@@ -148,6 +152,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
batch_size: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
# the alt_* fields are needed for KL divergence loss
alt_input_ids: Optional[torch.Tensor] = None,
@@ -179,27 +184,37 @@ def forward(
inputs_embeds = self.get_input_embeddings().forward(input_ids)

if audio_values is not None:

assert (
- audio_token_start_idx is not None and audio_token_len is not None
- ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
+ audio_token_start_idx is not None
+ and audio_token_len is not None
+ and batch_size is not None
+ ), "audio_token_start_idx and audio_token_len and batch_size must be provided if audio_values are provided."
assert (
- len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
- ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
+ len(audio_token_start_idx) == len(audio_token_len) == len(batch_size)
+ ), "audio_token_start_idx and audio_token_len and batch_size must have the same batch size."

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
audio_values
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

audio_embeds = self.multi_modal_projector.forward(audio_tower_output)

# combine audio and text embeddings
- for i, (audio, start, length) in enumerate(
-     zip(audio_embeds, audio_token_start_idx, audio_token_len)
+ audio_ind = 0
+ for i, (start, length, audio_batch_size) in enumerate(
+     zip(audio_token_start_idx, audio_token_len, batch_size)
):
audio = torch.cat(
[
audio_embeds[k]
for k in range(audio_ind, audio_ind + audio_batch_size)
],
dim=0,
)
length = min(length, audio.shape[0])
inputs_embeds[i, start : start + length] = audio[:length]
audio_ind += audio_batch_size

lm_output = self.language_model.forward(
inputs_embeds=inputs_embeds,
@@ -234,6 +249,7 @@ def prepare_inputs_for_generation(
audio_values: Optional[torch.FloatTensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
batch_size: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -262,6 +278,7 @@ def prepare_inputs_for_generation(
audio_token_start_idx - prefill_start_idx
)
model_input["audio_token_len"] = audio_token_len
model_input["batch_size"] = batch_size

return model_input

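The forward() change reassembles each sample's audio chunks using the per-sample counts in batch_size before splicing them into the text embeddings. A toy sketch of that bookkeeping, with invented dimensions and plain Python lists in place of the real tensors:

```python
import torch

# Invented sizes: 3 audio chunks total across 2 samples, 2 frames per chunk, hidden dim 4.
audio_embeds = torch.randn(3, 2, 4)            # (total_chunks, frames_per_chunk, hidden)
inputs_embeds = torch.zeros(2, 10, 4)          # (batch, seq_len, hidden)
audio_token_start_idx = [1, 2]
audio_token_len = [2, 4]
batch_size = [1, 2]                            # sample 0 -> 1 chunk, sample 1 -> 2 chunks

audio_ind = 0
for i, (start, length, audio_batch_size) in enumerate(
    zip(audio_token_start_idx, audio_token_len, batch_size)
):
    # Re-join this sample's chunks along the frame axis, then splice them into
    # the text embeddings at the sample's audio placeholder positions.
    audio = torch.cat(
        [audio_embeds[k] for k in range(audio_ind, audio_ind + audio_batch_size)],
        dim=0,
    )
    length = min(length, audio.shape[0])
    inputs_embeds[i, start : start + length] = audio[:length]
    audio_ind += audio_batch_size
```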
36 changes: 32 additions & 4 deletions ultravox/model/ultravox_processing.py
@@ -1,7 +1,8 @@
- from typing import Optional, Union
+ from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import transformers

from .ultravox_config import UltravoxConfig
@@ -89,6 +90,7 @@ def __call__(
text: Optional[str] = None,
audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
sampling_rate: Optional[int] = None,
audio_context_size: Optional[int] = None,
return_tensors: Optional[
Union[str, transformers.TensorType]
] = transformers.TensorType.PYTORCH,
@@ -132,7 +134,7 @@ def __call__(
- **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
"""
# TODO: Add support for multiple audio and text inputs.
- data = {}
+ data: Dict[str, Any] = {}
audio_embed_frames = 0
if audio is not None and len(audio) > 0:
if self.audio_padding == "max_length":
@@ -141,6 +143,7 @@
audio_len = 30 * sampling_rate
else:
audio_len = audio.shape[-1]

# It's guaranteed that the number of frames is less than or equal to this amount.
# For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
# Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
@@ -157,9 +160,34 @@ def __call__(
**kwargs,
)
if "input_features" in x:
data["audio_values"] = x.input_features
audio_values = x.input_features
else:
data["audio_values"] = x.input_values
audio_values = x.input_values

audio_values = torch.tensor(audio_values)
if audio_context_size and audio_values.shape[-1] > audio_context_size:
audio_values_chunks = list(
torch.split(
audio_values,
audio_context_size,
dim=len(audio_values.shape) - 1,
)
)
# Pad the last chunk to match audio_context_size
last_chunk = audio_values_chunks[-1]
pad_size = audio_context_size - last_chunk.shape[-1]
if pad_size > 0:
# Pad only the last dimension (T) in B,D,T format
audio_values_chunks[-1] = F.pad(
last_chunk, (0, pad_size, 0, 0, 0, 0)
)
else:
audio_values_chunks = [audio_values]

data["audio_values"] = torch.cat(audio_values_chunks)
num_audio_chunks = data["audio_values"].shape[0]

data["batch_size"] = [num_audio_chunks]

if text is not None:
assert isinstance(
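Putting the processor change together: assuming Whisper's feature extractor yields roughly 100 mel frames per second (so a 3000-frame context covers about 30 s), a 75 s clip produces 7500 feature frames, which get split into three 3000-frame chunks with the last one zero-padded, and batch_size records the chunk count. A standalone sketch of that arithmetic:

```python
import torch
import torch.nn.functional as F

# Assumption: ~100 mel frames per second, so a 3000-frame context covers 30 s
# and a 75 s clip yields 7500 feature frames.
audio_context_size = 3000
audio_values = torch.randn(1, 80, 7500)            # (1, n_mels, frames) from the extractor

chunks = list(torch.split(audio_values, audio_context_size, dim=-1))
pad_size = audio_context_size - chunks[-1].shape[-1]
if pad_size > 0:
    chunks[-1] = F.pad(chunks[-1], (0, pad_size))  # pad only the time dimension

data_audio_values = torch.cat(chunks)              # (3, 80, 3000)
data_batch_size = [data_audio_values.shape[0]]     # [3], later consumed by forward()
print(data_audio_values.shape, data_batch_size)
```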