From 8883907aade2f4e7e4f0ca1e045af53c2539a600 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Wed, 11 Sep 2024 00:39:14 -0700 Subject: [PATCH 1/8] Support longer audio contexts --- ultravox/model/ultravox_model.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index cb209888..637499da 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -51,6 +51,8 @@ def __init__(self, config: UltravoxConfig): self.vocab_size = config.vocab_size self.audio_tower = self._create_audio_tower(config) + self.audio_tower_context_length = 3000 # the context window for the whisper model + self.multi_modal_projector = UltravoxProjector(config) self.language_model = self._create_language_model(config) @@ -186,17 +188,30 @@ def forward( len(audio_token_start_idx) == len(audio_token_len) == len(audio_values) ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size." - # B x A/3200 x D - audio_tower_output = self.audio_tower.forward( - audio_values - ).last_hidden_state - audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) + # Check if the audio_values T dimension is greater than whisper encoder's context window. + print("audio values shape: ", audio_values.shape) + if audio_values.shape[2] > self.audio_tower_context_length: + audio_values_chunks = torch.split( + audio_values, self.audio_tower_context_length, dim=2 + ) + else: + audio_values_chunks = (audio_values,) + + rebuilt_audio_embeds = [] + for audio_chunk in audio_values_chunks: + # B x A/3200 x D + audio_tower_output = self.audio_tower.forward( + audio_chunk + ).last_hidden_state + audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) + audio_embeds = self.multi_modal_projector.forward(audio_tower_output) + rebuilt_audio_embeds.append(audio_embeds) - audio_embeds = self.multi_modal_projector.forward(audio_tower_output) + rebuilt_audio_embeds_tensor = torch.cat(rebuilt_audio_embeds, dim=1) # combine audio and text embeddings for i, (audio, start, length) in enumerate( - zip(audio_embeds, audio_token_start_idx, audio_token_len) + zip(rebuilt_audio_embeds_tensor, audio_token_start_idx, audio_token_len) ): length = min(length, audio.shape[0]) inputs_embeds[i, start : start + length] = audio[:length] From 7672177e5c9ff58ba978cc1834382aacc4ec15e2 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Wed, 11 Sep 2024 00:42:18 -0700 Subject: [PATCH 2/8] Formatting --- ultravox/model/ultravox_model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 637499da..a908f53c 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -51,8 +51,10 @@ def __init__(self, config: UltravoxConfig): self.vocab_size = config.vocab_size self.audio_tower = self._create_audio_tower(config) - self.audio_tower_context_length = 3000 # the context window for the whisper model - + self.audio_tower_context_length = ( + 3000 # the context window for the whisper model + ) + self.multi_modal_projector = UltravoxProjector(config) self.language_model = self._create_language_model(config) @@ -189,7 +191,6 @@ def forward( ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size." # Check if the audio_values T dimension is greater than whisper encoder's context window. 
- print("audio values shape: ", audio_values.shape) if audio_values.shape[2] > self.audio_tower_context_length: audio_values_chunks = torch.split( audio_values, self.audio_tower_context_length, dim=2 From e2e811505f2525f1c10f0a4c2105ff0199a226ca Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 13 Sep 2024 00:26:55 -0700 Subject: [PATCH 3/8] Working --- ultravox/data/datasets.py | 5 +-- ultravox/inference/infer.py | 4 ++- ultravox/model/ultravox_model.py | 51 +++++++++++++-------------- ultravox/model/ultravox_processing.py | 29 +++++++++++++-- 4 files changed, 58 insertions(+), 31 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 1ea40c95..6acf09b2 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -79,6 +79,7 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq): def __call__(self, features, *args, **kwargs): audio_values = [f.pop("audio_values", None) for f in features] + if self.include_alt_fields: # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method alt_features = [ @@ -89,6 +90,7 @@ def __call__(self, features, *args, **kwargs): } for f in features ] + input_ids_lens = torch.LongTensor([f["input_ids"].shape[-1] for f in features]) batch = super().__call__(features, *args, **kwargs) if self.include_alt_fields: @@ -100,7 +102,7 @@ def __call__(self, features, *args, **kwargs): # Pad the last dimension of all audio_values to the same length, with 0s on the right. if audio_values and audio_values[0] is not None: max_len = max([x.shape[-1] for x in audio_values]) - batch["audio_values"] = torch.stack( + batch["audio_values"] = torch.cat( [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values] ) if self.tokenizer.padding_side == "left": @@ -108,7 +110,6 @@ def __call__(self, features, *args, **kwargs): batch["audio_token_start_idx"] += displacement.to( batch["audio_token_start_idx"].device ) - return batch diff --git a/ultravox/inference/infer.py b/ultravox/inference/infer.py index f117c6ee..584b7b87 100644 --- a/ultravox/inference/infer.py +++ b/ultravox/inference/infer.py @@ -111,7 +111,8 @@ def infer_batch( inputs = [self._dataproc(s) for s in samples] for input in inputs: for key, val in input.items(): - input[key] = val.squeeze(0) + if key != "audio_values": + input[key] = val.squeeze(0) tensors = self.data_collator(inputs) input_len = tensors["input_ids"].shape[1] @@ -198,6 +199,7 @@ def _dataproc(self, sample: datasets.VoiceSample): text=text_input, return_tensors="pt", sampling_rate=SAMPLE_RATE, + audio_context_size=self.model.audio_tower_context_length, ) inputs = {k: v.to(self.model.device) for k, v in inputs.items()} if "audio_values" in inputs: diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index a908f53c..08b9c703 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -51,9 +51,9 @@ def __init__(self, config: UltravoxConfig): self.vocab_size = config.vocab_size self.audio_tower = self._create_audio_tower(config) - self.audio_tower_context_length = ( - 3000 # the context window for the whisper model - ) + self.audio_tower_context_length = None + if "whisper" in config.audio_model_id is not None: + self.audio_tower_context_length = 3000 self.multi_modal_projector = UltravoxProjector(config) self.language_model = self._create_language_model(config) @@ -152,6 +152,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, audio_token_start_idx: 
Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, + batch_size: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, # the alt_* fields are needed for KL divergence loss alt_input_ids: Optional[torch.Tensor] = None, @@ -183,39 +184,35 @@ def forward( inputs_embeds = self.get_input_embeddings().forward(input_ids) if audio_values is not None: + assert ( audio_token_start_idx is not None and audio_token_len is not None ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided." assert ( - len(audio_token_start_idx) == len(audio_token_len) == len(audio_values) - ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size." - - # Check if the audio_values T dimension is greater than whisper encoder's context window. - if audio_values.shape[2] > self.audio_tower_context_length: - audio_values_chunks = torch.split( - audio_values, self.audio_tower_context_length, dim=2 - ) - else: - audio_values_chunks = (audio_values,) - - rebuilt_audio_embeds = [] - for audio_chunk in audio_values_chunks: - # B x A/3200 x D - audio_tower_output = self.audio_tower.forward( - audio_chunk - ).last_hidden_state - audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) - audio_embeds = self.multi_modal_projector.forward(audio_tower_output) - rebuilt_audio_embeds.append(audio_embeds) + len(audio_token_start_idx) == len(audio_token_len) == len(batch_size) + ), "audio_token_start_idx and audio_token_len must have the same batch size." - rebuilt_audio_embeds_tensor = torch.cat(rebuilt_audio_embeds, dim=1) + audio_tower_output = self.audio_tower.forward( + audio_values + ).last_hidden_state + audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) + audio_embeds = self.multi_modal_projector.forward(audio_tower_output) # combine audio and text embeddings - for i, (audio, start, length) in enumerate( - zip(rebuilt_audio_embeds_tensor, audio_token_start_idx, audio_token_len) + audio_ind = 0 + for i, (start, length, audio_batch_size) in enumerate( + zip(audio_token_start_idx, audio_token_len, batch_size) ): + audio = torch.cat( + [ + audio_embeds[k] + for k in range(audio_ind, audio_ind + audio_batch_size) + ], + dim=0, + ) length = min(length, audio.shape[0]) inputs_embeds[i, start : start + length] = audio[:length] + audio_ind += audio_batch_size lm_output = self.language_model.forward( inputs_embeds=inputs_embeds, @@ -250,6 +247,7 @@ def prepare_inputs_for_generation( audio_values: Optional[torch.FloatTensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, + batch_size: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -278,6 +276,7 @@ def prepare_inputs_for_generation( audio_token_start_idx - prefill_start_idx ) model_input["audio_token_len"] = audio_token_len + model_input["batch_size"] = batch_size return model_input diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 211f7f0a..3fec458c 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -2,6 +2,7 @@ import numpy as np import torch +import torch.nn.functional as F import transformers from .ultravox_config import UltravoxConfig @@ -89,6 +90,7 @@ def __call__( text: Optional[str] = None, audio: 
Optional[Union[np.ndarray, torch.Tensor]] = None, sampling_rate: Optional[int] = None, + audio_context_size: Optional[int] = None, return_tensors: Optional[ Union[str, transformers.TensorType] ] = transformers.TensorType.PYTORCH, @@ -141,6 +143,7 @@ def __call__( audio_len = 30 * sampling_rate else: audio_len = audio.shape[-1] + # It's guaranteed that the number of frames is less than or equal to this amount. # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound. # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings. @@ -157,9 +160,30 @@ def __call__( **kwargs, ) if "input_features" in x: - data["audio_values"] = x.input_features + audio_values = x.input_features + else: + audio_values = x.input_values + + audio_values = torch.from_numpy(audio_values) + if audio_context_size and audio_values.shape[2] > audio_context_size: + audio_values_chunks = list( + torch.split(audio_values, audio_context_size, dim=2) + ) + # Pad the last chunk to match audio_context_size + last_chunk = audio_values_chunks[-1] + pad_size = audio_context_size - last_chunk.shape[2] + if pad_size > 0: + # Pad only the last dimension (T) in B,D,T format + audio_values_chunks[-1] = F.pad( + last_chunk, (0, pad_size, 0, 0, 0, 0) + ) else: - data["audio_values"] = x.input_values + audio_values_chunks = [audio_values] + + data["audio_values"] = torch.cat(audio_values_chunks, dim=0) + num_audio_chunks = data["audio_values"].shape[0] + + data["batch_size"] = [num_audio_chunks] if text is not None: assert isinstance( @@ -177,6 +201,7 @@ def __call__( add_special_tokens=False, ) ) + data["audio_token_start_idx"] = [start_idx] # Replace the audio placeholder with the audio token. From a5e2ac6dfec672f17d0a8363b2c4aceadce2d694 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 13 Sep 2024 01:09:33 -0700 Subject: [PATCH 4/8] Fix some tests --- ultravox/inference/infer_test.py | 1 + ultravox/model/ultravox_model.py | 12 +++++++----- ultravox/model/ultravox_processing.py | 20 +++++++++++++------- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py index f6a06c74..cf6b0609 100644 --- a/ultravox/inference/infer_test.py +++ b/ultravox/inference/infer_test.py @@ -60,6 +60,7 @@ def fake_generate(**kwargs): ) self.model.device = "cpu" self.model.generate = mock.MagicMock(side_effect=fake_generate) + self.model.audio_tower_context_length = 3000 EXPECTED_TOKEN_IDS_START = [128000, 128006, 882, 128007] diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 08b9c703..10d5a8af 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -51,8 +51,8 @@ def __init__(self, config: UltravoxConfig): self.vocab_size = config.vocab_size self.audio_tower = self._create_audio_tower(config) - self.audio_tower_context_length = None - if "whisper" in config.audio_model_id is not None: + self.audio_tower_context_length: Optional[int] = None + if config.audio_model_id is not None and "whisper" in config.audio_model_id: self.audio_tower_context_length = 3000 self.multi_modal_projector = UltravoxProjector(config) @@ -186,11 +186,13 @@ def forward( if audio_values is not None: assert ( - audio_token_start_idx is not None and audio_token_len is not None - ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided." 
+ audio_token_start_idx is not None + and audio_token_len is not None + and batch_size is not None + ), "audio_token_start_idx and audio_token_len and batch_size must be provided if audio_values are provided." assert ( len(audio_token_start_idx) == len(audio_token_len) == len(batch_size) - ), "audio_token_start_idx and audio_token_len must have the same batch size." + ), "audio_token_start_idx and audio_token_len and batch_size must have the same batch size." audio_tower_output = self.audio_tower.forward( audio_values diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 3fec458c..c54528e1 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import numpy as np import torch @@ -134,7 +134,7 @@ def __call__( - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`. """ # TODO: Add support for multiple audio and text inputs. - data = {} + data: Dict[str, Any] = {} audio_embed_frames = 0 if audio is not None and len(audio) > 0: if self.audio_padding == "max_length": @@ -164,14 +164,20 @@ def __call__( else: audio_values = x.input_values - audio_values = torch.from_numpy(audio_values) - if audio_context_size and audio_values.shape[2] > audio_context_size: + audio_values = torch.tensor(audio_values) + print("audio values shape", audio_values.shape) + print("audio_context_size", audio_context_size) + if audio_context_size and audio_values.shape[-1] > audio_context_size: audio_values_chunks = list( - torch.split(audio_values, audio_context_size, dim=2) + torch.split( + audio_values, + audio_context_size, + dim=len(audio_values.shape) - 1, + ) ) # Pad the last chunk to match audio_context_size last_chunk = audio_values_chunks[-1] - pad_size = audio_context_size - last_chunk.shape[2] + pad_size = audio_context_size - last_chunk.shape[-1] if pad_size > 0: # Pad only the last dimension (T) in B,D,T format audio_values_chunks[-1] = F.pad( @@ -180,7 +186,7 @@ def __call__( else: audio_values_chunks = [audio_values] - data["audio_values"] = torch.cat(audio_values_chunks, dim=0) + data["audio_values"] = torch.cat(audio_values_chunks) num_audio_chunks = data["audio_values"].shape[0] data["batch_size"] = [num_audio_chunks] From 495b894ae954746a85d0079bb947cc6e42b18935 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 13 Sep 2024 01:18:09 -0700 Subject: [PATCH 5/8] Fix some more tests --- ultravox/inference/infer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py index cf6b0609..b0cc7a17 100644 --- a/ultravox/inference/infer_test.py +++ b/ultravox/inference/infer_test.py @@ -60,7 +60,7 @@ def fake_generate(**kwargs): ) self.model.device = "cpu" self.model.generate = mock.MagicMock(side_effect=fake_generate) - self.model.audio_tower_context_length = 3000 + self.model.audio_tower_context_length = None EXPECTED_TOKEN_IDS_START = [128000, 128006, 882, 128007] From 0275766e81911d8c4e77b961b4820cbb88724281 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Fri, 13 Sep 2024 01:20:48 -0700 Subject: [PATCH 6/8] Remove prints --- ultravox/model/ultravox_processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index c54528e1..b80d84d3 100644 --- a/ultravox/model/ultravox_processing.py +++ 
b/ultravox/model/ultravox_processing.py @@ -165,8 +165,6 @@ def __call__( audio_values = x.input_values audio_values = torch.tensor(audio_values) - print("audio values shape", audio_values.shape) - print("audio_context_size", audio_context_size) if audio_context_size and audio_values.shape[-1] > audio_context_size: audio_values_chunks = list( torch.split( @@ -207,7 +205,6 @@ def __call__( add_special_tokens=False, ) ) - data["audio_token_start_idx"] = [start_idx] # Replace the audio placeholder with the audio token. From cc8ec190bed7b57b91acd6f897d2f479a2692c18 Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Thu, 19 Sep 2024 14:35:49 -0700 Subject: [PATCH 7/8] Move audio context length into processor init --- ultravox/inference/infer.py | 1 - ultravox/inference/infer_test.py | 3 +-- ultravox/inference/ultravox_infer.py | 5 ++++- ultravox/model/ultravox_pipeline.py | 1 + ultravox/model/ultravox_processing.py | 14 ++++++++++---- ultravox/training/train.py | 7 ++++++- 6 files changed, 22 insertions(+), 9 deletions(-) diff --git a/ultravox/inference/infer.py b/ultravox/inference/infer.py index 584b7b87..b2f8c068 100644 --- a/ultravox/inference/infer.py +++ b/ultravox/inference/infer.py @@ -199,7 +199,6 @@ def _dataproc(self, sample: datasets.VoiceSample): text=text_input, return_tensors="pt", sampling_rate=SAMPLE_RATE, - audio_context_size=self.model.audio_tower_context_length, ) inputs = {k: v.to(self.model.device) for k, v in inputs.items()} if "audio_values" in inputs: diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py index b0cc7a17..c44c1706 100644 --- a/ultravox/inference/infer_test.py +++ b/ultravox/inference/infer_test.py @@ -49,7 +49,7 @@ def fake_generate(**kwargs): return output processor = ultravox_processing.UltravoxProcessor( - audio_processor, tokenizer=tokenizer + audio_processor, tokenizer=tokenizer, audio_context_size=None ) super().__init__( mock.MagicMock(), @@ -60,7 +60,6 @@ def fake_generate(**kwargs): ) self.model.device = "cpu" self.model.generate = mock.MagicMock(side_effect=fake_generate) - self.model.audio_tower_context_length = None EXPECTED_TOKEN_IDS_START = [128000, 128006, 882, 128007] diff --git a/ultravox/inference/ultravox_infer.py b/ultravox/inference/ultravox_infer.py index 6765ece1..87911fcb 100644 --- a/ultravox/inference/ultravox_infer.py +++ b/ultravox/inference/ultravox_infer.py @@ -58,7 +58,10 @@ def __init__( ) processor = ultravox_processing.UltravoxProcessor( - audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor + audio_processor, + tokenizer=tokenizer, + stack_factor=model.config.stack_factor, + audio_context_size=model.audio_tower_context_length, ) super().__init__( diff --git a/ultravox/model/ultravox_pipeline.py b/ultravox/model/ultravox_pipeline.py index c9a8aaa1..33bff932 100644 --- a/ultravox/model/ultravox_pipeline.py +++ b/ultravox/model/ultravox_pipeline.py @@ -37,6 +37,7 @@ def __init__( audio_processor=audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor, + audio_context_size=model.audio_tower_context_length, ) super().__init__(model=model, tokenizer=tokenizer, **kwargs) diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index b80d84d3..b1a1d207 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -39,6 +39,9 @@ def __init__( encoder_ds_factor: int = 320, stack_factor: int = 8, audio_placeholder: str = "<|audio|>", + audio_context_size: Optional[ + int + ] = 3000, # Defaults 
to whisper encoder context size ): """ Args: @@ -54,6 +57,7 @@ def __init__( self.stack_factor = stack_factor self.audio_placeholder = audio_placeholder self.audio_token_replacement = tokenizer.eos_token + self.audio_context_size = audio_context_size assert ( self.audio_token_replacement is not None ), "The tokenizer has no EOS token. Cannot recover." @@ -90,7 +94,6 @@ def __call__( text: Optional[str] = None, audio: Optional[Union[np.ndarray, torch.Tensor]] = None, sampling_rate: Optional[int] = None, - audio_context_size: Optional[int] = None, return_tensors: Optional[ Union[str, transformers.TensorType] ] = transformers.TensorType.PYTORCH, @@ -165,17 +168,20 @@ def __call__( audio_values = x.input_values audio_values = torch.tensor(audio_values) - if audio_context_size and audio_values.shape[-1] > audio_context_size: + if ( + self.audio_context_size + and audio_values.shape[-1] > self.audio_context_size + ): audio_values_chunks = list( torch.split( audio_values, - audio_context_size, + self.audio_context_size, dim=len(audio_values.shape) - 1, ) ) # Pad the last chunk to match audio_context_size last_chunk = audio_values_chunks[-1] - pad_size = audio_context_size - last_chunk.shape[-1] + pad_size = self.audio_context_size - last_chunk.shape[-1] if pad_size > 0: # Pad only the last dimension (T) in B,D,T format audio_values_chunks[-1] = F.pad( diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 05cea992..ee79934b 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -117,7 +117,6 @@ def train(args: config_base.TrainConfig): text_tokenizer.padding_side = "right" text_tokenizer.pad_token = text_tokenizer.eos_token audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model) - processor = ultravox_processing.UltravoxProcessor(audio_processor, text_tokenizer) # Instantiate the model and processor config = ultravox_config.UltravoxConfig( @@ -134,6 +133,12 @@ def train(args: config_base.TrainConfig): with ddp_utils.run_on_master_first(is_master): model = ultravox_model.UltravoxModel(config) + processor = ultravox_processing.UltravoxProcessor( + audio_processor, + text_tokenizer, + audio_context_size=model.audio_tower_context_length, + ) + assert model.get_input_embeddings().num_embeddings == len( text_tokenizer ), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}" From d427b38d53995f80918dded7be7010f3938c24cf Mon Sep 17 00:00:00 2001 From: Patrick Li Date: Thu, 19 Sep 2024 16:11:30 -0700 Subject: [PATCH 8/8] Add tests --- ultravox/inference/infer_test.py | 36 ++++++++++++++++++++++++--- ultravox/model/ultravox_model.py | 27 ++++++++++---------- ultravox/model/ultravox_processing.py | 2 +- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py index c44c1706..dee1f125 100644 --- a/ultravox/inference/infer_test.py +++ b/ultravox/inference/infer_test.py @@ -1,3 +1,4 @@ +from typing import Optional from unittest import mock import numpy as np @@ -30,16 +31,23 @@ def audio_processor(): ) +@pytest.fixture(scope="module") +def audio_processor_whisper(): + return transformers.AutoProcessor.from_pretrained("openai/whisper-tiny") + + class FakeInference(infer.LocalInference): def __init__( self, tokenizer: transformers.PreTrainedTokenizer, audio_processor: transformers.ProcessorMixin, + audio_context_size: Optional[int] = None, ): def fake_generate(**kwargs): input = kwargs.get("input_ids") + input_len = 
input.shape[1] if input is not None else 0 output = transformers.generation.utils.GenerateDecoderOnlyOutput( - sequences=[range(25)] + sequences=[range(input_len + 5)] # Always output 5 tokens ) streamer = kwargs.get("streamer", None) if streamer: @@ -49,7 +57,7 @@ def fake_generate(**kwargs): return output processor = ultravox_processing.UltravoxProcessor( - audio_processor, tokenizer=tokenizer, audio_context_size=None + audio_processor, tokenizer=tokenizer, audio_context_size=audio_context_size ) super().__init__( mock.MagicMock(), @@ -66,6 +74,26 @@ def fake_generate(**kwargs): EXPECTED_TOKEN_IDS_END = [128009, 128006, 78191, 128007, 271] +def test_long_audio_context(tokenizer, audio_processor_whisper): + """Ensure we handle long audio context properly.""" + inference = FakeInference( + tokenizer, audio_processor_whisper, audio_context_size=3000 + ) + array = np.ones(960000, dtype=np.float32) + sample = datasets.VoiceSample.from_prompt_and_raw( + "Transcribe\n<|audio|>", array, 16000 + ) + output = inference.infer(sample) + assert output.input_tokens == 388 + assert output.output_tokens == 5 + assert output.text == "ers on conapub" + generate_args = inference.model.generate.call_args[1] + assert generate_args["audio_values"].shape == (2, 80, 3000) + assert generate_args["audio_token_len"].item() == torch.tensor(375) + assert generate_args["audio_token_start_idx"] == torch.tensor(8) + assert generate_args["audio_batch_size"] == torch.tensor(2) + + def test_infer_16kHz(tokenizer, audio_processor): """Ensure we handle 16kHz float32 audio properly.""" inference = FakeInference(tokenizer, audio_processor) @@ -148,8 +176,8 @@ def test_infer_text_only(tokenizer, audio_processor): sample = datasets.VoiceSample.from_prompt("Hello?") output = inference.infer(sample) assert output.input_tokens == 12 - assert output.output_tokens == 13 - assert output.text == "-./0123456789" + assert output.output_tokens == 5 + assert output.text == "-./01" generate_args = inference.model.generate.call_args[1] assert generate_args.get("audio_values") is None call_input_ids = generate_args["input_ids"] diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index 10d5a8af..97b6c3d1 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -152,7 +152,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, - batch_size: Optional[torch.Tensor] = None, + audio_batch_size: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, # the alt_* fields are needed for KL divergence loss alt_input_ids: Optional[torch.Tensor] = None, @@ -188,11 +188,13 @@ def forward( assert ( audio_token_start_idx is not None and audio_token_len is not None - and batch_size is not None - ), "audio_token_start_idx and audio_token_len and batch_size must be provided if audio_values are provided." + and audio_batch_size is not None + ), "audio_token_start_idx and audio_token_len and audio_batch_size must be provided if audio_values are provided." assert ( - len(audio_token_start_idx) == len(audio_token_len) == len(batch_size) - ), "audio_token_start_idx and audio_token_len and batch_size must have the same batch size." + len(audio_token_start_idx) + == len(audio_token_len) + == len(audio_batch_size) + ), "audio_token_start_idx and audio_token_len and audio_batch_size must have the same batch size." 
audio_tower_output = self.audio_tower.forward( audio_values @@ -202,19 +204,16 @@ def forward( # combine audio and text embeddings audio_ind = 0 - for i, (start, length, audio_batch_size) in enumerate( - zip(audio_token_start_idx, audio_token_len, batch_size) + for i, (start, length, batch_size) in enumerate( + zip(audio_token_start_idx, audio_token_len, audio_batch_size) ): audio = torch.cat( - [ - audio_embeds[k] - for k in range(audio_ind, audio_ind + audio_batch_size) - ], + [audio_embeds[k] for k in range(audio_ind, audio_ind + batch_size)], dim=0, ) length = min(length, audio.shape[0]) inputs_embeds[i, start : start + length] = audio[:length] - audio_ind += audio_batch_size + audio_ind += batch_size lm_output = self.language_model.forward( inputs_embeds=inputs_embeds, @@ -249,7 +248,7 @@ def prepare_inputs_for_generation( audio_values: Optional[torch.FloatTensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, - batch_size: Optional[torch.Tensor] = None, + audio_batch_size: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, @@ -278,7 +277,7 @@ def prepare_inputs_for_generation( audio_token_start_idx - prefill_start_idx ) model_input["audio_token_len"] = audio_token_len - model_input["batch_size"] = batch_size + model_input["audio_batch_size"] = audio_batch_size return model_input diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index b1a1d207..ab518a78 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -193,7 +193,7 @@ def __call__( data["audio_values"] = torch.cat(audio_values_chunks) num_audio_chunks = data["audio_values"].shape[0] - data["batch_size"] = [num_audio_chunks] + data["audio_batch_size"] = [num_audio_chunks] if text is not None: assert isinstance(
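
For reference, the core technique this series introduces — splitting a long mel-spectrogram along its time dimension into encoder-context-sized chunks, zero-padding the final chunk, and stacking the chunks along the batch dimension so each one fits Whisper's fixed 3000-frame (30-second) context — can be sketched in isolation. The helper below is a hypothetical standalone illustration of the logic the patches fold into UltravoxProcessor.__call__; chunk_audio_features is not a function in the repository.

    import torch
    import torch.nn.functional as F

    WHISPER_CONTEXT = 3000  # mel frames: 30 s of audio at 100 frames/s (10 ms hop)

    def chunk_audio_features(
        audio_values: torch.Tensor, context_size: int = WHISPER_CONTEXT
    ) -> torch.Tensor:
        """Split a (B, D, T) feature tensor along T into context-sized chunks,
        zero-pad the last chunk on the right, and stack the chunks along the
        batch dimension."""
        if audio_values.shape[-1] <= context_size:
            return audio_values
        chunks = list(torch.split(audio_values, context_size, dim=-1))
        pad = context_size - chunks[-1].shape[-1]
        if pad > 0:
            chunks[-1] = F.pad(chunks[-1], (0, pad))  # pad only the trailing (T) dim
        return torch.cat(chunks, dim=0)

    # 100 s of audio -> 10000 mel frames -> four 3000-frame chunks (last one padded)
    features = torch.randn(1, 80, 10000)
    print(chunk_audio_features(features).shape)  # torch.Size([4, 80, 3000])

At inference time the model then runs every chunk through the audio tower independently and, guided by the per-sample audio_batch_size counts, re-concatenates the projected chunk embeddings in order before splicing them into the text embedding sequence at audio_token_start_idx.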