From 27c61da5fd822ed79532b8efddf28cf055358665 Mon Sep 17 00:00:00 2001 From: Zhongqiang Huang Date: Mon, 16 Sep 2024 17:12:29 -0400 Subject: [PATCH 1/4] Update --- mcloud_eval.yaml | 22 + mcloud.yaml => mcloud_train.yaml | 10 +- ultravox/data/dataset_config.py | 23 - ultravox/data/datasets.py | 783 ++++++++---------- ultravox/data/datasets_test.py | 23 +- ultravox/evaluation/configs/eval_config.yaml | 103 +++ .../evaluation/configs/eval_config_2k.yaml | 110 +++ .../evaluation/configs/eval_config_tiny.yaml | 40 + ultravox/evaluation/eval.py | 134 ++- ultravox/evaluation/eval_helpers.py | 35 + ultravox/evaluation/eval_types.py | 21 +- ultravox/evaluation/gpt_eval_conv.py | 4 +- ultravox/evaluation/gpt_eval_test.py | 4 +- ultravox/evaluation/string_based.py | 39 - ultravox/evaluation/string_evals.py | 62 ++ ultravox/evaluation/wer.py | 33 - ultravox/inference/infer.py | 40 +- ultravox/inference/ultravox_infer.py | 10 +- ultravox/inference/utils.py | 26 - ultravox/tools/data_tool.py | 8 +- ultravox/tools/eval_tool.py | 142 +++- ultravox/tools/infer_tool.py | 255 ------ ultravox/tools/push_to_hub.py | 5 +- ultravox/tools/score_tool.py | 51 ++ ultravox/training/config_base.py | 200 +++-- ultravox/training/configs/asr_llama.yaml | 3 - ultravox/training/configs/asr_tinyllama.yaml | 4 - .../training/configs/asr_tinyllama_100s.yaml | 7 - ultravox/training/configs/llama3_whisper.yaml | 6 - .../training/configs/llama3_whisper_kd.yaml | 39 - ultravox/training/configs/llama_whisper.yaml | 102 +++ ultravox/training/configs/meta_config.yaml | 192 ++++- ultravox/training/configs/release_config.yaml | 31 +- ultravox/training/configs/stage2_lora.yaml | 23 - .../training/configs/tinyllama_whisper.yaml | 102 +++ ultravox/training/ddp_utils.py | 21 +- ultravox/training/evaluation.py | 186 ----- ultravox/training/train.py | 229 ++--- ultravox/utils/device_helpers.py | 49 ++ ultravox/utils/string_helpers.py | 59 ++ 40 files changed, 1818 insertions(+), 1418 deletions(-) create mode 100644 mcloud_eval.yaml rename mcloud.yaml => mcloud_train.yaml (70%) delete mode 100644 ultravox/data/dataset_config.py create mode 100644 ultravox/evaluation/configs/eval_config.yaml create mode 100644 ultravox/evaluation/configs/eval_config_2k.yaml create mode 100644 ultravox/evaluation/configs/eval_config_tiny.yaml create mode 100644 ultravox/evaluation/eval_helpers.py delete mode 100644 ultravox/evaluation/string_based.py create mode 100644 ultravox/evaluation/string_evals.py delete mode 100644 ultravox/evaluation/wer.py delete mode 100644 ultravox/inference/utils.py delete mode 100644 ultravox/tools/infer_tool.py create mode 100644 ultravox/tools/score_tool.py delete mode 100644 ultravox/training/configs/asr_llama.yaml delete mode 100644 ultravox/training/configs/asr_tinyllama.yaml delete mode 100644 ultravox/training/configs/asr_tinyllama_100s.yaml delete mode 100644 ultravox/training/configs/llama3_whisper.yaml delete mode 100644 ultravox/training/configs/llama3_whisper_kd.yaml create mode 100644 ultravox/training/configs/llama_whisper.yaml delete mode 100644 ultravox/training/configs/stage2_lora.yaml create mode 100644 ultravox/training/configs/tinyllama_whisper.yaml delete mode 100644 ultravox/training/evaluation.py create mode 100644 ultravox/utils/device_helpers.py create mode 100644 ultravox/utils/string_helpers.py diff --git a/mcloud_eval.yaml b/mcloud_eval.yaml new file mode 100644 index 00000000..c3fbf542 --- /dev/null +++ b/mcloud_eval.yaml @@ -0,0 +1,22 @@ +# Ultravox training configuration +name: ultravox +image: 
mosaicml/composer:latest +compute: + gpus: 8 + cluster: r14z3p1 +integrations: + - integration_type: git_repo + git_repo: fixie-ai/ultravox-expts + git_branch: $UV_BRANCH + pip_install: poetry==1.7.1 +scheduling: + max_duration: 2 # 2 hours max for jobs to avoid hanging jobs +command: >- + cd ultravox-expts && + poetry install --no-dev && + poetry run python -m ultravox.training.helpers.prefetch_weights $EVAL_ARGS && + poetry run torchrun --nproc_per_node=8 -m ultravox.tools.eval_tool $EVAL_ARGS --exp_name $UV_BRANCH +env_variables: + MLFLOW_TRACKING_URI: databricks + UV_BRANCH: zhuang.2024-08-12-ultravox.merge_p1 + EVAL_ARGS: --config_path ultravox/evaluation/configs/eval_config_2k.yaml diff --git a/mcloud.yaml b/mcloud_train.yaml similarity index 70% rename from mcloud.yaml rename to mcloud_train.yaml index 91e089c1..53fb24d0 100644 --- a/mcloud.yaml +++ b/mcloud_train.yaml @@ -6,17 +6,17 @@ compute: cluster: r14z3p1 integrations: - integration_type: git_repo - git_repo: fixie-ai/ultravox + git_repo: fixie-ai/ultravox-expts git_branch: $UV_BRANCH pip_install: poetry==1.7.1 scheduling: max_duration: 6 # 6 hours max for jobs to avoid hanging jobs command: >- - cd ultravox && + cd ultravox-expts && poetry install --no-dev && poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS && - poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS + poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS --exp_name $UV_BRANCH env_variables: MLFLOW_TRACKING_URI: databricks - UV_BRANCH: main - TRAIN_ARGS: --config_path ultravox/training/configs/release_config.yaml + UV_BRANCH: zhuang.2024-08-12-ultravox.merge_p1 + TRAIN_ARGS: --config_path ultravox/training/configs/expt_config.yaml \ No newline at end of file diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py deleted file mode 100644 index ea1f0a60..00000000 --- a/ultravox/data/dataset_config.py +++ /dev/null @@ -1,23 +0,0 @@ -import dataclasses -from typing import List, Optional - -from pydantic import BaseModel - - -class DataDictConfig(BaseModel): - # Path to the dataset, or huggingface dataset id - path: str - # Name of the dataset, or huggingface dataset config/subset - name: Optional[str] = None - splits: List[str] = dataclasses.field(default_factory=list) - num_samples: Optional[int] = None - total_samples: int = 1 - weight: float = 1.0 - streaming: bool = True - user_template: str = "<|audio|>" - assistant_template: str = "{{text}}" - transcript_template: str = "{{text}}" - - def __post_init__(self): - if not self.splits: - raise ValueError("At least one split must be provided") diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 1ea40c95..93e9951c 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1,13 +1,13 @@ import abc import base64 import dataclasses -import enum import io import itertools import logging import os import tempfile import warnings +from contextlib import closing from enum import Enum from typing import Any, Callable, Dict, List, Optional, Sequence @@ -21,17 +21,26 @@ import torch import torch.nn.functional as F import transformers +from pydantic import BaseModel from torch.utils import data -from ultravox.data import dataset_config from ultravox.data import text_proc +from ultravox.evaluation.eval_types import EvalConfig +from ultravox.utils import device_helpers +from ultravox.utils import string_helpers SAMPLE_RATE = 16000 -TRANSCRIBE_INPUT_TASK = "transcribe_input" 
-TRANSCRIBE_OUTPUT_TASK = "transcribe_output"
-ANSWER_TASK = "answer"
-
+TRANSLATE_PROMPTS = [
+    "Translate the following into {target}, without any explanation: <|audio|>",
+    "Translate the following into {target} language, no explanation needed: <|audio|>",
+    "Please convert the following into {target}. Be concise.\n<|audio|>",
+    "Could you translate this to {target} language? No commentary necessary.\n<|audio|>",
+    "Translate the text below to {target}.\n<|audio|>",
+    "Translate the subsequent text into {target} language. <|audio|>",
+    "Can you translate this into the {target} language?\n<|audio|>",
+    "Transform the following to {target}: <|audio|>",
+]
 TRANSCRIBE_PROMPTS = [
     # from Gazelle
     "Transcribe\n<|audio|>",
@@ -67,16 +76,118 @@
 os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json"
 os.environ["GOOGLE_CLOUD_PROJECT"] = "fixie-training"
 
-
 # Silence the spurious warnings coming from the MosaicML streaming library.
 logging.getLogger("streaming.base.dataset").setLevel(logging.ERROR)
 
 
+# Global arguments for voice datasets.
 @dataclasses.dataclass
-class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
-    # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel
-    include_alt_fields: bool = False
+class VoiceDatasetArgs:
+    """Global arguments for voice datasets."""
+
+    batch_size: int = 4
+    """Batch size for train, eval, or validation."""
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate. If None, use the model's max context length."""
+    temperature: float = 0.0
+    """Temperature for sampling."""
+    include_context: bool = True
+    """Whether to include additional textual context from the dataset in the prompt."""
+    max_context_length: int = 1500
+    """Maximum length of context to include in the prompt. Otherwise, skip the sample."""
+    shuffle: bool = False
+    """Whether to shuffle the dataset."""
+    shuffle_seed: int = 42
+    """Seed for shuffling the dataset."""
+    max_audio_duration_secs: Optional[float] = None
+    """Skip samples with audio longer than this duration, in seconds; None means no limit."""
+    num_prompts: int = 1
+    """Default number of canned prompts to sample from; individual datasets may override this."""
+    use_mds: bool = False
+    """Whether to load the dataset from GCP (using MDS) or Hugging Face."""
+    mds_batch_size: int = 32
+    """Batch size for MDS."""
+
+class DatasetConfig(BaseModel):
+    type: str = "generic"
+    """Type of dataset class, including "generic" or predefined types like "anyinstruct", "boolq", etc."""
+    alias: str = ""
+    """Alias used to reference the dataset; constructed automatically if not provided and sanitized for use in filenames"""
+    path: str = ""
+    """Directory of the dataset, or huggingface dataset name; must be set for "generic" datasets. If not set, it is automatically inferred for predefined dataset types."""
+    name: Optional[str] = None
+    """Name of the dataset, or huggingface dataset config/subset name"""
+    splits: List[str] = dataclasses.field(default_factory=list)
+    """List of splits to use, e.g. 
["train", "validation"]""" + active_samples: Optional[int] = None + """Number of active samples to use, or None for all samples""" + num_samples: int = -1 + """Total number of samples from this dataset, must be set for computing the size of the dataset""" + weight: float = 1.0 + """Weight of the dataset, used for weighting the dataset in the training loop""" + streaming: bool = True + """Whether to stream the dataset""" + num_prompts: Optional[int] = None + """If predefined prompts are used, the number of prompts to use; overrides the global default""" + user_template: str = "<|audio|>" + """Template for the user's message""" + user_template_args: Dict[str, str] = {} + """Optional arguments (e.g., target language) for the user template""" + assistant_template: str = "{{text}}" + """Template for the assistant's message""" + transcript_template: str = "{{text}}" + """Template for the transcript""" + audio_field: Optional[str] = "audio" + """Field in the dataset that contains the audio, use None if the dataset does not contain audio""" + base_audio_columns: List[str] = dataclasses.field(default_factory=list) + """Base audio columns to use for the dataset; audio_field is included by default""" + eval_config: Optional[EvalConfig] = None + """Evaluation configuration: metric and arguments""" + use_mds: bool = False + """Set to True to load the dataset from GCP (using MDS) instead of Hugging Face""" + mds_batch_size: int = 32 + """Batch size for MDS""" + + class Config: + extra = "forbid" + # do not allow undefined parameters + + def model_post_init(self, __context: Any) -> None: + if not self.splits: + raise ValueError("At least one split must be provided") + + # Construct the alias if not provided + if self.alias is None: + alias: List[str] = [] + if self.type != "generic": + alias.append(self.type) + if self.path: + alias.append(self.path) + if self.name: + alias.append(self.name) + if self.splits: + alias.append(":".join(self.splits)) + self.alias = "/".join(alias) + sanitized_alias = string_helpers.sanitize_name(self.alias) + if sanitized_alias != self.alias: + print( + f"Alias '{self.alias}' sanitized to '{sanitized_alias}' for use as a valid filename" + ) + self.alias = sanitized_alias + + if self.num_samples < 0: + raise ValueError( + "num_samples must be non-negative for determing dataset size in all cases (streaming or not)" + ) + + # Ensure the audio field is always included + if self.audio_field is not None and self.audio_field not in self.base_audio_columns: + self.base_audio_columns.append(self.audio_field) + + +@dataclasses.dataclass +class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq): def __call__(self, features, *args, **kwargs): audio_values = [f.pop("audio_values", None) for f in features] if self.include_alt_fields: @@ -213,42 +324,6 @@ def add_past_messages(self, past_messages: List[Dict[str, str]]): """For evaluations, the known transcript of the audio.""" -class DatasetSplit(str, enum.Enum): - TRAIN = "train" - VALIDATION = "validation" - - -@dataclasses.dataclass -class VoiceDatasetArgs: - data_dir: Optional[str] = None - prompt: Optional[str] = None - """A specific prompt to use for the dataset.""" - num_prompts: int = 1 - """If `prompt` is not set, the number of canned prompts to use.""" - include_audio: bool = True - """Whether to include audio in the samples.""" - include_context: bool = True - """Whether to include additional textual context from the dataset to the prompt.""" - max_context_length: int = 1500 - """Maximum length of context to 
include in the prompt. Otherwise, skip the sample.""" - shuffle: bool = False - """Whether to shuffle the dataset.""" - shuffle_seed: int = 42 - """Seed for shuffling the dataset.""" - max_audio_duration_secs: Optional[float] = None - """Whether to skip samples with audio longer than this duration.""" - use_mds: bool = False - """Whether to load the dataset from GCP (using MDS) or Hugging Face.""" - mds_batch_size: int = 32 - """Batch size for MDS.""" - split: DatasetSplit = DatasetSplit.TRAIN - """Which split of the dataset to use.""" - - def __post_init__(self): - if isinstance(self.split, str): - self.split = DatasetSplit(self.split.lower()) - - def _get_messages( *turns: str, sys_prompt: Optional[str] = None, assistant_last: bool = True ) -> List[Dict[str, str]]: @@ -289,24 +364,31 @@ class VoiceDataset(SizedIterableDataset): Wraps a Hugging Face dataset or MDS-formatted dataset from GCP. """ - BASE_AUDIO_COLUMNS = ["audio"] - - def __init__(self, args: VoiceDatasetArgs) -> None: + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: super().__init__() self._args = args - self._session: Optional[requests.Session] = None + self._config = config self._rng = np.random.default_rng(self._args.shuffle_seed) - self._weight = 1.0 # the default weight for the dataset + if device_helpers.get_local_rank() == 0: + logging.info( + f"Created VoiceDataset with config:\n{self._config.model_dump_json(indent=2)}" + ) - def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None: + # _dataset and _length are initialized in _init_dataset + self._dataset = None + self._length = 0 + + def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None: self._dataset = dataset - # Only required when using epochs when training dataset. - self._estimated_length = estimated_length + self._length = num_samples @property def weight(self) -> float: - return self._weight + return self._config.weight + def __len__(self): + return self._length + def _load_audio_dataset( self, path: str, @@ -316,7 +398,6 @@ def _load_audio_dataset( shuffle: Optional[bool] = None, streaming: bool = True, ) -> data.Dataset: - logging.info(f"Loading dataset {path} {name} {split} {shuffle} {streaming}") if shuffle is None: shuffle = self._args.shuffle if self._args.use_mds: @@ -337,10 +418,12 @@ def _load_audio_dataset( shuffle_seed=self._args.shuffle_seed, ) else: + # HF datasets sometimes fails to download due to network issues, so retry a few times. dataset = datasets.load_dataset( - path, name, split=split, trust_remote_code=True, streaming=streaming + path, name, split=split, trust_remote_code=True, streaming=streaming, + download_config=datasets.DownloadConfig(max_retries=10) ) - for column_name in self.BASE_AUDIO_COLUMNS: + for column_name in self._config.base_audio_columns: dataset = dataset.cast_column( column_name, datasets.Audio(sampling_rate=SAMPLE_RATE) ) @@ -352,30 +435,43 @@ def __iter__(self): actual_length = 0 for _, row in enumerate(self._dataset): sample = self._get_sample(row) - if sample is not None: - if ( - self._args.max_audio_duration_secs is None - or sample.audio is None - or sample.audio.shape[-1] / SAMPLE_RATE - <= self._args.max_audio_duration_secs - ): + if sample is None: + raise ValueError( + f"Sample is None in dataset {self._config.alias} for row {row}" + ) + else: + if self._config.audio_field is not None: + # If audio_field is set, make sure the sample has audio data. 
+ if sample.audio is None: + raise ValueError( + f"Audio field ({self._config.audio_field}) is None in dataset {self._config.alias} for sample {sample}" + ) + if sample.audio.shape[-1] == 0: + raise ValueError( + f"Audio length is 0 in dataset {self._config.alias} for sample {sample}" + ) + if ( + self._args.max_audio_duration_secs is not None + and sample.audio.shape[-1] / SAMPLE_RATE + > self._args.max_audio_duration_secs + ): + warnings.warn( + f"Audio length ({sample.audio.shape[-1] / SAMPLE_RATE}s) exceeds max audio duration ({self._args.max_audio_duration_secs}s) in dataset {self._config.alias}, skipping sample." + ) + else: + yield sample + else: yield sample actual_length += 1 - # If len(dataset) == 0 most likely the dataset is a validation dataset, - # or the training is using max_steps instead of num_epochs. - if actual_length > len(self) and len(self) > 1: + if actual_length == len(self) + 1: warnings.warn( - f"The estimated length {self._estimated_length} has been exceeded for type {type(self._dataset)}. Make sure to update." + f"The actual number of samples ({actual_length}) has exceeded the presumed length ({self._length}) for dataset {self._config.alias}. Make sure to update." ) - - if actual_length != len(self) and len(self) > 1: + if actual_length != len(self): warnings.warn( - f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update." + f"Mismatch between actual length ({actual_length}) and presumed length ({self._length}) for dataset {self._config.alias}. Make sure to update." ) - def __len__(self): - return self._estimated_length - @abc.abstractmethod def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: """ @@ -384,35 +480,34 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: """ def _choice(self, prompts: List[str]) -> str: - return self._rng.choice(prompts[: self._args.num_prompts]) + # If num_prompts is not set for a dataset, use the global default. + return self._rng.choice( + prompts[: self._config.num_prompts or self._args.num_prompts] + ) def _get_answer_prompt(self) -> str: - if self._args.prompt: - return self._args.prompt return self._choice(ANSWER_PROMPTS) def _get_transcribe_prompt(self) -> str: - if self._args.prompt: - return self._args.prompt return self._choice(TRANSCRIBE_PROMPTS) def _get_answer_messages( self, question: str, answer: str, context: Optional[str] = None ) -> List[Dict[str, str]]: - prompt = self._get_answer_prompt() if self._args.include_audio else question + prompt = self._get_answer_prompt() if self._config.audio_field else question user_content = f"{context}\n\n{prompt}" if context else prompt return _get_messages(user_content, answer) def _get_transcribe_messages(self, text: str) -> List[Dict[str, str]]: prompt = self._get_transcribe_prompt() - if not self._args.include_audio: + if self._config.audio_field is None: prompt = prompt.replace("<|audio|>", text) return _get_messages(prompt, text) def _get_audio( self, row: transformers.BatchFeature, column_name: str = "audio" ) -> np.ndarray: - if column_name not in self.BASE_AUDIO_COLUMNS: + if column_name not in self._config.base_audio_columns: raise ValueError( f"Unknown audio column: {column_name}. This is likely a bug and the audio might not be resampled to {SAMPLE_RATE} Hz." 
) @@ -430,18 +525,30 @@ def _get_audio( assert sampling_rate == SAMPLE_RATE return audio - def _load_audio(self, base_url: str, folder: str, filename: str) -> np.ndarray: - if self._args.data_dir: - audio_path = f"{self._args.data_dir}/{folder}/{filename}" - audio = audio_from_file(audio_path) - else: + def _load_audio( + self, base_url: Optional[str], data_dir: Optional[str], filename: str + ) -> np.ndarray: + if base_url is not None: url = f"{base_url}/{filename}" # hack for GCS bucket naming - if self._session is None: - self._session = requests.Session() - response = self._session.get(url) - response.raise_for_status() - audio = audio_from_buf(response.content) - return audio + try: + with closing(requests.Session()) as session: + response = session.get(url) + response.raise_for_status() + return audio_from_buf(response.content) + except requests.RequestException as e: + raise ValueError( + f"Failed to load audio from URL: {url}. Error: {str(e)}" + ) + elif data_dir is not None: + audio_path = os.path.join(data_dir, filename) + try: + return audio_from_file(audio_path) + except IOError as e: + raise ValueError( + f"Failed to load audio file: {audio_path}. Error: {str(e)}" + ) + else: + raise ValueError("Either base_url or data_dir must be provided") def _get_transcribe_sample( self, @@ -465,37 +572,82 @@ def _make_sample( audio: np.ndarray, audio_transcript: Optional[str] = None, ) -> VoiceSample: - if not self._args.include_audio: + if self._config.audio_field is None: return VoiceSample(messages) return VoiceSample(messages, audio, audio_transcript=audio_transcript) -class LibriSpeechDummyDataset(VoiceDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "hf-internal-testing/librispeech_asr_dummy", - "clean", - split="validation", - streaming=False, # not supported by the dummy dataset +class GenericVoiceDataset(VoiceDataset): + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + super().__init__(args, config) + if not config.path: + raise ValueError( + "path is required for GenericVoiceDataset, manually set in dataset config or automatically inferred from predefined dataset type" + ) + num_samples = config.num_samples + dataset = datasets.concatenate_datasets( + [ + self._load_audio_dataset( + config.path, + name=config.name, + split=s, + streaming=config.streaming, + shuffle=False, + ) + for s in config.splits + ] ) - self._init_dataset(dataset, 73) + # shuffling is only supported on huggingface datasets for now, not MDS + if self._args.shuffle: + dataset = dataset.shuffle(seed=self._args.shuffle_seed) - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text) + if config.active_samples: + dataset = Range( + SizedIterableDataset(dataset, config.num_samples), config.active_samples + ) + num_samples = config.active_samples + + # super().__init__(args, config, dataset, num_samples) + super()._init_dataset(dataset, num_samples) + + def _get_sample(self, row) -> Optional[VoiceSample]: + try: + user_content = jinja2.Template( + self._config.user_template, undefined=jinja2.StrictUndefined + ).render( + **row, + text_proc=text_proc, + dataset=self, + **self._config.user_template_args, + ) + assistant_content = jinja2.Template( + self._config.assistant_template, undefined=jinja2.StrictUndefined + ).render(**row, text_proc=text_proc, dataset=self) + transcript = jinja2.Template( + 
self._config.transcript_template, undefined=jinja2.StrictUndefined + ).render(**row, text_proc=text_proc, dataset=self) + except jinja2.TemplateError as e: + print(f"Error rendering template: {e}") + print(f"user_template: {self._config.user_template}") + print(f"assistant_template: {self._config.assistant_template}") + print(f"transcript_template: {self._config.transcript_template}") + print(f"sample keys: {list(row.keys())}") + raise ValueError( + f"Template rendering failed. Make sure all keys in the template exist in the sample." + ) from e + + return self._make_sample( + _get_messages(user_content, assistant_content), + self._get_audio(row, self._config.audio_field), + audio_transcript=transcript, + ) # Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training. class EmptyDataset(SizedIterableDataset): - def __init__(self, estimated_length: int = 1) -> None: - self._estimated_length = estimated_length - def __iter__(self): return iter([]) - def __len__(self): - return self._estimated_length - class AnyInstructDataset(VoiceDataset): """ @@ -506,9 +658,9 @@ class AnyInstructDataset(VoiceDataset): ]} """ - def __init__(self, args: VoiceDatasetArgs) -> None: + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: # TODO(juberti): convert to MDS - super().__init__(args) + super().__init__(args, config) dataset = datasets.load_dataset( "json", "anyinstruct", @@ -518,12 +670,16 @@ def __init__(self, args: VoiceDatasetArgs) -> None: dataset = dataset.train_test_split( test_size=0.01, seed=args.shuffle_seed, shuffle=True ) - dataset = dataset["train" if args.split == DatasetSplit.TRAIN else "test"] + if len(self._config.splits) > 1: + raise ValueError( + "AnyInstructDataset is hardcoded to only support one split: train or test" + ) + dataset = dataset["train" if self._config.splits[0] == "train" else "test"] # TODO: make num_shards configurable if need be dataset = dataset.to_iterable_dataset(num_shards=16) if args.shuffle: dataset = dataset.shuffle(seed=args.shuffle_seed) - self._init_dataset(dataset) + self._init_dataset(dataset, len(dataset)) def _load_anyinstruct_audio(self, filename: str): return super()._load_audio( @@ -534,8 +690,8 @@ def _load_anyinstruct_audio(self, filename: str): class AnyInstructAnswerDataset(AnyInstructDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + super().__init__(args, config) def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: chat = row["chat"] @@ -547,8 +703,8 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: class AnyInstructInputDataset(AnyInstructDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + super().__init__(args, config) def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: audio_transcript = row["chat"][0]["message"] @@ -560,8 +716,8 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: class AnyInstructOutputDataset(AnyInstructDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + super().__init__(args, config) def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: audio_transcript = row["chat"][1]["message"] 
@@ -572,13 +728,10 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: ) -class BoolQDataset(VoiceDataset): - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/boolq-audio", split=args.split.value - ) - self._init_dataset(dataset) +class BoolQDataset(GenericVoiceDataset): + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + config.path = config.path or "fixie-ai/boolq-audio" + super().__init__(args, config) def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: question = row["question"] @@ -633,16 +786,12 @@ def _get_query_prompt(self, question_str: str, context: str) -> Optional[str]: # Skip samples with long context return None - if self._args.prompt: - prompt = self._args.prompt - else: - prompt = self._choice(self.PROMPT_SUFFIXES) - + prompt = self._config.user_template or self._choice(self.PROMPT_SUFFIXES) # Separate either with 1 or 2 newlines separator = self._choice(self.SEPARATORS) query_prompt = self._choice(self.QUERY_PREFIX) - question = "<|audio|>" if self._args.include_audio else question_str + question = "<|audio|>" if self._config.audio_field else question_str prompt = f"{query_prompt}{question}{separator}{prompt}" if self._args.include_context: @@ -707,12 +856,9 @@ class HeySQuADHumanDataset(QAVoiceDatasetMixin): This dataset is the human-spoken version of HeySQuAD. """ - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/HeySQuAD_human", split=args.split.value - ) - self._init_dataset(dataset) + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + config.path = config.path or "fixie-ai/HeySQuAD_human" + super().__init__(args, config) def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: """ @@ -753,14 +899,12 @@ class SlueSQA5Dataset(QAVoiceDatasetMixin): Splits: train, validation, test, verified_test """ - BASE_AUDIO_COLUMNS = ["question_audio", "document_audio"] - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "asapp/slue-phase-2", "sqa5", split=args.split.value - ) - self._init_dataset(dataset) + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + config.path = config.path or "asapp/slue-phase-2" + config.name = config.name or "sqa5" + if len(config.base_audio_columns) != 2: + config.base_audio_columns = ["question_audio", "document_audio"] + super().__init__(args, config) def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: """ @@ -784,209 +928,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: ) -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class LibriSpeechDataset(VoiceDataset): - """ - LibriSpeech is a corpus of approximately 1000 hours of 16kHz read - English speech. The data is derived from read audiobooks from the - LibriVox project. A simple automatic procedure was used to select - the audio in the first two sets to be, on average, of higher - recording quality and with accents closer to US English. - https://huggingface.co/datasets/librispeech_asr - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - # TODO(juberti): convert to MDS, in a way that preserves the same - # concatenation of the three splits. 
MDS can interleave but not - # concatenate, it seems. - super().__init__(args) - ds: Any - if args.split == DatasetSplit.VALIDATION: - ds = self._load_audio_dataset("librispeech_asr", split="validation.clean") - else: - splits = ["train.clean.100", "train.clean.360", "train.other.500"] - ds = datasets.concatenate_datasets( - [ - self._load_audio_dataset("librispeech_asr", split=s, shuffle=False) - for s in splits - ] - ) - if self._args.shuffle: - ds = ds.shuffle(seed=self._args.shuffle_seed) - self._init_dataset(ds) - - def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class GigaSpeechDataset(VoiceDataset): - """ - GigaSpeech is an evolving, multi-domain English speech recognition corpus - with 10,000 hours of high quality labeled audio suitable for supervised training. - "s" split is 250 hours. Non-commercial use only. - https://huggingface.co/datasets/speechcolab/gigaspeech - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "speechcolab/gigaspeech", "xl", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class VoxPopuliDataset(VoiceDataset): - """ - VoxPopuli is a large-scale multilingual speech corpus for representation learning, - semi-supervised learning and interpretation. - "en" split is 543 hours. - https://huggingface.co/datasets/facebook/voxpopuli - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "facebook/voxpopuli", "en", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="raw_text") - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class CommonVoiceDataset(VoiceDataset): - """ - The Common Voice dataset consists of a unique MP3 and corresponding text file - https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1 - Dataset({ - features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'], - num_rows: 1090061 - }) - NOTE: requires HF login - """ - - def __init__(self, args: VoiceDatasetArgs, lang: str = "en") -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "mozilla-foundation/common_voice_16_1", lang, split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="sentence") - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class CoVoST2Dataset(VoiceDataset): - """ - CoVoST 2 is a large-scale multilingual speech translation corpus covering translations from 21 languages into English - and from English into 15 languages. The dataset is created using Mozilla's open-source Common Voice 4 database of - crowdsourced voice recordings. There are 2,900 hours of speech represented in the corpus. 
- - The original Hugging Face dataset link: https://huggingface.co/datasets/facebook/covost2 - Since this dataset requires audio files to be downloaded separately, a new dataset is created with the audio files: - https://huggingface.co/datasets/fixie-ai/covost2 - - Due to the scale of the dataset and the audio files being repeated, only a portion of the dataset was converted. - See [this issue](https://github.com/fixie-ai/ultravox/issues/50) for more information. - - Supported subsets (En -> X): - 'en_de', 'en_tr', 'en_fa', 'en_sv-SE', 'en_mn', 'en_zh-CN', 'en_cy', - 'en_ca', 'en_sl', 'en_et', 'en_id', 'en_ar', 'en_ta', 'en_lv', 'en_ja' - Supported subsets (X -> En): - 'fr_en', 'zh-CN_en', 'es_en' - """ - - CODE_TO_LANG = { - "en": "English", - "de": "German", - "tr": "Turkish", - "fa": "Persian", - "sv-SE": "Swedish", - "mn": "Mongolian", - "zh-CN": "Chinese", - "cy": "Welsh", - "ca": "Catalan", - "sl": "Slovenian", - "et": "Estonian", - "id": "Indonesian", - "ar": "Arabic", - "ta": "Tamil", - "lv": "Latvian", - "ja": "Japanese", - "fr": "French", - "es": "Spanish", - } - - # We currently don't use this dataset for training, so mainly the first prompt it ever used. - # The "no explanation" part is important, specially for evaluations, but it's not repeated - # in all prompts to avoid being too repetitive in training. - TRANSLATE_PROMPTS = [ - "Translate the following into {target}, without any explanation: <|audio|>", - "Translate the following into {target} language, no explanation needed: <|audio|>", - "Please convert the following into {target}. Be concise.\n<|audio|>", - "Could you translate this to {target} language? No commentary necessary.\n<|audio|>", - "Translate the text below to {target}.\n<|audio|>", - "Translate the subsequent text into {target} language. <|audio|>", - "Can you translate this into the {target} language?\n<|audio|>", - "Transform the following to {target}: <|audio|>", - ] - - def __init__(self, args: VoiceDatasetArgs, subset: str) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/covost2", subset, split=args.split.value - ) - langs = subset.split("_") - assert len(langs) == 2, f"Invalid subset: {subset}" - self.source_lang = self.CODE_TO_LANG[langs[0]] - self.target_lang = self.CODE_TO_LANG[langs[1]] - self._init_dataset(dataset) - - def _get_sample(self, row) -> VoiceSample: - prompt = self._choice(self.TRANSLATE_PROMPTS).format(target=self.target_lang) - - transcript = row["sentence"] - translation = row["translation"] - if not self._args.include_audio: - prompt = prompt.replace("<|audio|>", transcript) - - return self._make_sample( - _get_messages(prompt, translation), - self._get_audio(row), - audio_transcript=transcript, - ) - - -# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future. -class PeopleSpeechDataset(VoiceDataset): - """ - The People's Speech Dataset is among the world's largest English speech - recognition corpus. It includes 30,000+ hours of transcribed speech in - English languages with a diverse set of speakers. 
- https://huggingface.co/datasets/MLCommons/peoples_speech - """ - - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "MLCommons/peoples_speech", "clean", split=args.split.value - ) - self._init_dataset(dataset) - - def _get_sample(self, row) -> Optional[VoiceSample]: - return self._get_transcribe_sample(row, tcol="text") - - -class SodaDataset(VoiceDataset): - BASE_AUDIO_COLUMNS = ["audio_second_last_turn"] - +class SodaDataset(GenericVoiceDataset): SYS_PROMPTS = [ "Follow the flow of the conversation and respond just like a human would in the same situation.", "Engage in the conversation naturally, responding as a human would.", @@ -1001,12 +943,13 @@ class SodaDataset(VoiceDataset): "Maintain the flow of the chat and answer just as a person would.", ] - def __init__(self, args: VoiceDatasetArgs) -> None: - super().__init__(args) - dataset = self._load_audio_dataset( - "fixie-ai/soda-audio", split=args.split.value - ) - self._init_dataset(dataset) + def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: + config.path = config.path or "fixie-ai/soda-audio" + if config.audio_field != "audio_second_last_turn": + warnings.warn( + f"SODA dataset expects audio field 'audio_second_last_turn' but found '{config.audio_field}'. Please make sure this is intentional." + ) + super().__init__(args, config) def _get_sample(self, row) -> VoiceSample: turns = row["dialogue"] @@ -1018,78 +961,21 @@ def _get_sample(self, row) -> VoiceSample: messages = _get_messages(*turns[:-1], sys_prompt=sys_prompt) messages[-1]["content"] = row["alt_last_turn"] - if self._args.include_audio: + if self._config.audio_field: messages[-2]["content"] = "<|audio|>" return self._make_sample( messages, - audio=self._get_audio(row, "audio_second_last_turn"), + audio=self._get_audio(row, self._config.audio_field), audio_transcript=turns[-2], ) -class GenericVoiceDataset(VoiceDataset): - def __init__( - self, args: VoiceDatasetArgs, config: dataset_config.DataDictConfig - ) -> None: - super().__init__(args) - dataset = datasets.concatenate_datasets( - [ - self._load_audio_dataset( - config.path, - name=config.name, - split=s, - streaming=config.streaming, - shuffle=False, - ) - for s in config.splits - ] - ) - # shuffling is only supported on huggingface datasets for now, not MDS - if self._args.shuffle: - dataset = dataset.shuffle(seed=self._args.shuffle_seed) - - if config.num_samples: - dataset = Range(dataset, config.num_samples, config.total_samples) - - self._weight = config.weight - - self.user_template = config.user_template - self.assistant_template = config.assistant_template - self.transcript_template = config.transcript_template - - super()._init_dataset(dataset, config.total_samples) - - def _get_sample(self, row) -> VoiceSample: - try: - user_content = jinja2.Template( - self.user_template, undefined=jinja2.StrictUndefined - ).render(**row, text_proc=text_proc, dataset=self) - assistant_content = jinja2.Template( - self.assistant_template, undefined=jinja2.StrictUndefined - ).render(**row, text_proc=text_proc, dataset=self) - transcript = jinja2.Template( - self.transcript_template, undefined=jinja2.StrictUndefined - ).render(**row, text_proc=text_proc, dataset=self) - except jinja2.TemplateError as e: - print(f"Error rendering template: {e}") - print(f"user_template: {self.user_template}") - print(f"assistant_template: {self.assistant_template}") - print(f"transcript_template: {self.transcript_template}") - print(f"sample keys: 
{list(row.keys())}") - raise ValueError( - f"Template rendering failed. Make sure all keys in the template exist in the sample." - ) from e - - return self._make_sample( - _get_messages(user_content, assistant_content), - self._get_audio(row), - audio_transcript=transcript, - ) - - -def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: +def create_dataset( + args: VoiceDatasetArgs, config: DatasetConfig +) -> SizedIterableDataset: DATASET_MAP: Dict[str, Any] = { + "generic": GenericVoiceDataset, "anyinstruct": AnyInstructAnswerDataset, "anyinstruct_in": AnyInstructInputDataset, "anyinstruct_out": AnyInstructOutputDataset, @@ -1098,26 +984,15 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> SizedIterableDataset: "boolq_extended": BoolQWithExtendedAnswerDataset, "heysquad_human": HeySQuADHumanDataset, "slue_sqa5": SlueSQA5Dataset, - "gigaspeech": GigaSpeechDataset, - "librispeech": LibriSpeechDataset, - "voxpopuli": VoxPopuliDataset, - "commonvoice": CommonVoiceDataset, - "covost2": CoVoST2Dataset, - "peoplespeech": PeopleSpeechDataset, "soda": SodaDataset, - "dummy": LibriSpeechDummyDataset, } - if isinstance(name, dataset_config.DataDictConfig): - return GenericVoiceDataset(args, name) - else: - name, *ext = name.split(":") - return DATASET_MAP[name](args, *ext) + return DATASET_MAP[config.type](args, config) class StopStrategy(str, Enum): - FIRST_EXHAUSTED = "first_exhausted" - LAST_EXHAUSTED = "last_exhausted" - NEVER_STOP = "never_stop" + FIRST_EXHAUSTED = "FIRST_EXHAUSTED" + LAST_EXHAUSTED = "LAST_EXHAUSTED" + NEVER_STOP = "NEVER_STOP" class InterleaveDataset(SizedIterableDataset): @@ -1207,36 +1082,28 @@ class Range(SizedIterableDataset): def __init__( self, - dataset: data.IterableDataset, + dataset: SizedIterableDataset, num_samples: Optional[int] = None, - total_samples: Optional[int] = None, ) -> None: self._dataset = dataset - self._num_samples = num_samples - - if isinstance(self._dataset, SizedIterableDataset): - self._estimated_length = len(self._dataset) + if num_samples is None: + self._length = len(dataset) else: - if total_samples is None: + if num_samples > len(dataset): raise ValueError( - "total_samples must be provided for non-SizedIterableDataset." + f"num_samples {num_samples} is greater than the number of samples in the dataset {len(dataset)}" ) - self._estimated_length = total_samples - - if self._num_samples is not None and self._num_samples > self._estimated_length: - # Issuing a warning here instead of raising an error to accomodate for specific classes of VoiceDataset - # Once we migrate entirely to GenericVoiceDataset, we can raise an error here. 
- warnings.warn("num_samples is greater than total_samples.") + self._length = num_samples def __iter__(self): - for i, sample in enumerate(self._dataset): - if self._num_samples is not None and i >= self._num_samples: + count = 0 + for sample in self._dataset: + if self._length is not None and count >= self._length: break yield sample + count += 1 - def __len__(self): - return ( - self._num_samples - if self._num_samples is not None - else self._estimated_length - ) + if count != self._length: + raise ValueError( + f"Dataset exhausted after {count} samples, expected {self._length}" + ) diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index faacd3fa..6cfea7f0 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -8,7 +8,6 @@ from torch.utils import data from transformers.feature_extraction_utils import BatchFeature -from ultravox.data import dataset_config from ultravox.data import datasets @@ -51,8 +50,16 @@ def __iter__(self): class FakeTranscribeDataset(datasets.VoiceDataset): """Fake version of our VoiceDataset using a transcribe prompt.""" - def __init__(self, n: int, args: Optional[datasets.VoiceDatasetArgs] = None): - super().__init__(args or datasets.VoiceDatasetArgs()) + def __init__( + self, + n: int, + args: Optional[datasets.VoiceDatasetArgs] = None, + config: Optional[datasets.DatasetConfig] = None, + ): + super().__init__( + args or datasets.VoiceDatasetArgs(), + config or datasets.DatasetConfig(num_samples=n), + ) self._init_dataset(FakeHuggingFaceIterableDataset(n), n) @@ -66,11 +73,11 @@ class FakeGenericDataset(datasets.VoiceDataset): def __init__( self, n: int, - config: dataset_config.DataDictConfig, + config: datasets.DatasetConfig, args: Optional[datasets.VoiceDatasetArgs] = None, ): - super().__init__(args or datasets.VoiceDatasetArgs()) - self._init_dataset(FakeHuggingFaceIterableDataset(n), config.total_samples) + super().__init__(args or datasets.VoiceDatasetArgs(), config) + self._init_dataset(FakeHuggingFaceIterableDataset(n), config.num_samples) def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]: return self._get_transcribe_sample(row) @@ -324,7 +331,7 @@ def test_get_messages(): def test_voice_dataset_size(): - mock_config = dataset_config.DataDictConfig(path="mock_path", total_samples=5) + mock_config = datasets.DatasetConfig(path="mock_path", total_samples=5) ds = FakeGenericDataset(10, mock_config) assert len(ds) == 5 @@ -332,7 +339,7 @@ def test_voice_dataset_size(): with pytest.warns(UserWarning, match=pattern): list(ds) - mock_config = dataset_config.DataDictConfig(path="mock_path", total_samples=10) + mock_config = datasets.DatasetConfig(path="mock_path", total_samples=10) ds = FakeGenericDataset(5, mock_config) assert len(ds) == 10 diff --git a/ultravox/evaluation/configs/eval_config.yaml b/ultravox/evaluation/configs/eval_config.yaml new file mode 100644 index 00000000..53c866a9 --- /dev/null +++ b/ultravox/evaluation/configs/eval_config.yaml @@ -0,0 +1,103 @@ +model_path: "fixie-ai/ultravox-v0_4" +report_logs_to: ["wandb"] +output_dir: "./output" + +eval_dataset_args: + batch_size: 40 + max_tokens: 512 + +# full test sets, used for comparision with OAI evaluation to ensure consistency +eval_dataset_configs: + - alias: "covost2-en_de-test" + path: "fixie-ai/covost2" + name: "en_de" + splits: + - "test" + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "German" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-en_zh-test" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" + args: + tokenize: "zh" + - alias: "covost2-en_ar-test" + path: "fixie-ai/covost2" + name: "en_ar" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Arabic" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-en_ca-test" + path: "fixie-ai/covost2" + name: "en_ca" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Catalan" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-es_en-test" + path: "fixie-ai/covost2" + name: "es_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-ru_en-test" + path: "fixie-ai/covost2" + name: "ru_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-zh_en-test" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + num_samples: 15531 + eval_config: + metric: "bleu" diff --git a/ultravox/evaluation/configs/eval_config_2k.yaml b/ultravox/evaluation/configs/eval_config_2k.yaml new file mode 100644 index 00000000..b69305b5 --- /dev/null +++ b/ultravox/evaluation/configs/eval_config_2k.yaml @@ -0,0 +1,110 @@ +model_path: "fixie-ai/ultravox-v0_4" +report_logs_to: ["wandb"] +output_dir: "./output" + +eval_dataset_args: + batch_size: 40 + max_tokens: 512 + +# subset of the eval_config.yaml test sets, used during model development to score quickly +eval_dataset_configs: + - alias: "covost2-en_de-test-2k" + path: "fixie-ai/covost2" + name: "en_de" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "German" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-en_zh-test-2k" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + args: + tokenize: "zh" + - alias: "covost2-en_ar-test-2k" + path: "fixie-ai/covost2" + name: "en_ar" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Arabic" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-en_ca-test-2k" + path: "fixie-ai/covost2" + name: "en_ca" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Catalan" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-es_en-test-2k" + path: "fixie-ai/covost2" + name: "es_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-ru_en-test-2k" + path: "fixie-ai/covost2" + name: "ru_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-zh_en-test-2k" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" \ No newline at end of file diff --git a/ultravox/evaluation/configs/eval_config_tiny.yaml b/ultravox/evaluation/configs/eval_config_tiny.yaml new file mode 100644 index 00000000..963a93bd --- /dev/null +++ b/ultravox/evaluation/configs/eval_config_tiny.yaml @@ -0,0 +1,40 @@ +model_path: "fixie-ai/ultravox-v0_4" +report_logs_to: ["wandb"] +output_dir: "./output" + +eval_dataset_args: + batch_size: 4 + max_tokens: 50 + +# tiny test sets for debugging +eval_dataset_configs: + - alias: "covost2-en_zh-test-20" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 20 + num_samples: 15531 + eval_config: + metric: "bleu" + args: + tokenize: "zh" + - alias: "covost2-zh_en-test-20" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 20 + num_samples: 15531 + eval_config: + metric: "bleu" \ No newline at end of file diff --git a/ultravox/evaluation/eval.py b/ultravox/evaluation/eval.py index bc92e726..8eb16022 100644 --- a/ultravox/evaluation/eval.py +++ b/ultravox/evaluation/eval.py @@ -1,22 +1,114 @@ +import json +import logging +import os +from typing import List, Optional + +from tqdm import tqdm + +from ultravox.data import datasets +from ultravox.evaluation import eval_helpers from ultravox.evaluation import eval_types -from ultravox.evaluation import gpt_eval_boolq -from ultravox.evaluation import gpt_eval_conv -from ultravox.evaluation import gpt_eval_instruct -from ultravox.evaluation import string_based -from ultravox.evaluation import wer - -METRIC_REGISTRY = { - "asr": wer.evaluate_answer_asr, - "boolq": gpt_eval_boolq.evaluate_answer_boolq, - "instruct": gpt_eval_instruct.evaluate_answer_instruct, - "conversation": gpt_eval_conv.evaluate_conversation_response, - "exact_match_last_word": string_based.match_last_word, - "bleu": string_based.bleu, -} - - -def evaluate_answer(sample: eval_types.Sample, metric: str) -> eval_types.Result: - if metric in METRIC_REGISTRY: - return METRIC_REGISTRY[metric](sample) - else: - raise ValueError(f"Unknown metric: {metric}") +from ultravox.inference import infer +from ultravox.training import ddp_utils + +logging.basicConfig(level=logging.INFO) + + +def dataset_infer( + inference: infer.LocalInference, + dataset: datasets.SizedIterableDataset, + world_size: int = 1, + local_rank: int = 0, + batch_size: int = 1, + max_tokens: Optional[int] = None, + temperature: float = 0.0, +) -> List[eval_types.Sample]: + results: List[eval_types.Sample] = [] + if local_rank == 0: + total_batches = len(dataset) // (batch_size * world_size) + progress_bar = tqdm(total=total_batches) + for batch_input in ddp_utils.sharded_batch_iterator( + dataset, batch_size, world_size, local_rank + ): + batch_indices = [idx for idx, _ in batch_input] + batch_samples = [sample for _, sample in batch_input] + batch_references = [] + for sample in batch_samples: + assistant_message = sample.messages.pop() + if assistant_message["role"] != "assistant": + raise ValueError( + f"Expected assistant message but got: role={assistant_message['role']}, content={assistant_message['content']}" + ) + batch_references.append(assistant_message["content"]) + batch_output = inference.infer_batch( + batch_samples, + max_tokens=max_tokens, + temperature=temperature, + ) + for index, sample, output, reference in zip( + batch_indices, batch_samples, batch_output, batch_references + ): + results.append( + eval_types.Sample( + index=index, + question=sample.messages[-1]["content"], + reference=reference, + hypothesis=output.text, + ) + ) + + if local_rank == 0: + progress_bar.update(1) + if local_rank == 0: + progress_bar.close() + return results + + +def run_infer( + inference: infer.LocalInference, + dataset_args: datasets.VoiceDatasetArgs, + dataset_configs: List[datasets.DatasetConfig], + world_size: int, + local_rank: int, + output_dir: Optional[str] = None, +): + metrics = {} + output_files = [] + for dataset_config in dataset_configs: + dataset = datasets.create_dataset(dataset_args, dataset_config) + results = dataset_infer( + inference, + dataset, + world_size=world_size, + 
local_rank=local_rank, + batch_size=dataset_args.batch_size, + max_tokens=dataset_args.max_tokens, + temperature=dataset_args.temperature, + ) + results = ddp_utils.all_gather_list(results) + + # compute metrics, if specified, and save results only on the first process + if local_rank == 0: + # ensure same order + results.sort(key=lambda x: x.index) + dataset_alias = dataset_config.alias + if dataset_config.eval_config: + eval_result: eval_types.Result = eval_helpers.evaluate_answers( + results, dataset_config.eval_config + ) + logging.info( + f"Eval: {dataset_alias}, {dataset_config.eval_config.metric}: {eval_result.score:.2f}" + ) + + metrics[f"eval_{dataset_alias}-{dataset_config.eval_config.metric}"] = ( + eval_result.score + ) + + if output_dir: + output_file = os.path.join(output_dir, f"{dataset_alias}.json") + with open(output_file, "w") as f: + results_json = [result.to_dict() for result in results] + json.dump(results_json, f, ensure_ascii=False, indent=2) + print(f"Results saved to {output_file}") + output_files.append(output_file) + return metrics, output_files diff --git a/ultravox/evaluation/eval_helpers.py b/ultravox/evaluation/eval_helpers.py new file mode 100644 index 00000000..b410619e --- /dev/null +++ b/ultravox/evaluation/eval_helpers.py @@ -0,0 +1,35 @@ +from typing import List + +from ultravox.evaluation import eval_types +from ultravox.evaluation import gpt_eval_boolq +from ultravox.evaluation import gpt_eval_conv +from ultravox.evaluation import gpt_eval_instruct +from ultravox.evaluation import string_evals + +METRIC_REGISTRY = { + "asr": string_evals.wer, + "boolq": gpt_eval_boolq.evaluate_answer_boolq, + "instruct": gpt_eval_instruct.evaluate_answer_instruct, + "conversation": gpt_eval_conv.evaluate_conversation_response, + "exact_match_last_word": string_evals.match_last_word, +} + +CORPUS_METRIC_REGISTRY = {"bleu": string_evals.bleu} + + +def evaluate_answer(sample: eval_types.Sample, metric: str) -> eval_types.Result: + if metric in METRIC_REGISTRY: + return METRIC_REGISTRY[metric](sample) + else: + raise ValueError(f"Unknown metric: {metric}") + + +def evaluate_answers( + samples: List[eval_types.Sample], metric_config: eval_types.EvalConfig +) -> eval_types.Result: + if metric_config.metric in CORPUS_METRIC_REGISTRY: + return CORPUS_METRIC_REGISTRY[metric_config.metric]( + samples, **metric_config.args + ) + else: + raise ValueError(f"Unknown metric: {metric_config.metric}") diff --git a/ultravox/evaluation/eval_types.py b/ultravox/evaluation/eval_types.py index d2bddc11..7a28edb8 100644 --- a/ultravox/evaluation/eval_types.py +++ b/ultravox/evaluation/eval_types.py @@ -1,15 +1,26 @@ import dataclasses -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import dataclasses_json +from pydantic import BaseModel + +# Eval config for a single metric, added to the dataset config +class EvalConfig(BaseModel): + metric: str + args: Dict[str, Any] = dataclasses.field(default_factory=dict) + + class Config: + extra = "forbid" + # do not allow undefined parameters @dataclasses.dataclass class Sample(dataclasses_json.DataClassJsonMixin): + index: int # index of the sample in the dataset, used for preserving order after ddp all_gather question: str - generated_answer: str - expected_answer: str - history: List[Dict[str, str]] = dataclasses.field(default_factory=list) + hypothesis: str + reference: str + history: List[Dict[str, Any]] = dataclasses.field(default_factory=list) @dataclasses.dataclass @@ -45,4 +56,4 @@ class 
BleuResult: score: float -Result = Union[InstructResult, WerResult, ExactMatchResult] +Result = Union[InstructResult, WerResult, ExactMatchResult, BleuResult] diff --git a/ultravox/evaluation/gpt_eval_conv.py b/ultravox/evaluation/gpt_eval_conv.py index 066d0fd8..a2898c15 100644 --- a/ultravox/evaluation/gpt_eval_conv.py +++ b/ultravox/evaluation/gpt_eval_conv.py @@ -17,8 +17,8 @@ {%- for turn in history + [ {"role": "user", "content": question} ] %} {% if turn["role"] == "user" %}A{% else %}B{% endif %}: {{ turn["content"] }} {% endfor %} - Model (as B): {{ generated_answer }} - Correct: {{ expected_answer }} + Model (as B): {{ hypothesis }} + Correct: {{ reference }} """ diff --git a/ultravox/evaluation/gpt_eval_test.py b/ultravox/evaluation/gpt_eval_test.py index aa45f904..c185048a 100644 --- a/ultravox/evaluation/gpt_eval_test.py +++ b/ultravox/evaluation/gpt_eval_test.py @@ -15,8 +15,8 @@ def test_evaluate_conversation(): {"role": "assistant", "content": "T2"}, ], question="T3", - generated_answer="T4", - expected_answer="EXP", + hypothesis="T4", + reference="EXP", ) expected_turns = "A: T1\n\nB: T2\n\nA: T3\n\nModel (as B): T4\nCorrect: EXP" diff --git a/ultravox/evaluation/string_based.py b/ultravox/evaluation/string_based.py deleted file mode 100644 index d835ead0..00000000 --- a/ultravox/evaluation/string_based.py +++ /dev/null @@ -1,39 +0,0 @@ -import re - -import sacrebleu - -from ultravox.evaluation import eval_types - - -def match_last_word(sample: eval_types.Sample) -> eval_types.ExactMatchResult: - last_words = re.findall(r"\b\w+\b(?=\W*$)", sample.generated_answer.lower()) - expected_tf = re.findall(r"\b\w+\b(?=\W*$)", sample.expected_answer.lower())[-1] - - if not last_words: - return eval_types.ExactMatchResult(score=0, reason="No last word found") - - last_word: str = last_words[-1] - if last_word in ["yes", "true"]: - last_word = "true" - elif last_word in ["no", "false"]: - last_word = "false" - else: - return eval_types.ExactMatchResult(score=0, reason="Last word not true/false") - - return eval_types.ExactMatchResult( - score=last_word == expected_tf, reason="exact_match check" - ) - - -def bleu(sample: eval_types.Sample) -> eval_types.BleuResult: - """ - Compute BLEU score for a single sample. - - Note: BLEU is supposed to be computed on a corpus level, not on a single sample. - As such, reported values here might not be easily comparable to other metrics. 
- """ - score = sacrebleu.sentence_bleu( - hypothesis=sample.generated_answer, - references=[sample.expected_answer], - ).score - return eval_types.BleuResult(score=score) diff --git a/ultravox/evaluation/string_evals.py b/ultravox/evaluation/string_evals.py new file mode 100644 index 00000000..75de4b55 --- /dev/null +++ b/ultravox/evaluation/string_evals.py @@ -0,0 +1,62 @@ +import re +from typing import List + +import jiwer +import sacrebleu + +from ultravox.evaluation import eval_types + + +def wer(sample: eval_types.Sample): + transforms = jiwer.Compose( + [ + jiwer.ExpandCommonEnglishContractions(), + jiwer.RemoveEmptyStrings(), + jiwer.ToLowerCase(), + jiwer.RemoveMultipleSpaces(), + jiwer.Strip(), + jiwer.RemovePunctuation(), + jiwer.ReduceToListOfListOfWords(), + ] + ) + + score = jiwer.wer( + [sample.reference], + [sample.hypothesis], + truth_transform=transforms, + hypothesis_transform=transforms, + ) + + return eval_types.WerResult(score=min(1.0, score)) + + +def match_last_word(sample: eval_types.Sample) -> eval_types.ExactMatchResult: + last_words = re.findall(r"\b\w+\b(?=\W*$)", sample.hypothesis.lower()) + expected_tf = re.findall(r"\b\w+\b(?=\W*$)", sample.reference.lower())[-1] + + if not last_words: + return eval_types.ExactMatchResult(score=0, reason="No last word found") + + last_word: str = last_words[-1] + if last_word in ["yes", "true"]: + last_word = "true" + elif last_word in ["no", "false"]: + last_word = "false" + else: + return eval_types.ExactMatchResult(score=0, reason="Last word not true/false") + + return eval_types.ExactMatchResult( + score=last_word == expected_tf, reason="exact_match check" + ) + + +def bleu(samples: List[eval_types.Sample], **kwargs) -> eval_types.BleuResult: + """ + Compute corpus BLEU score for a list of samples. 
+ """ + references = [[sample.reference for sample in samples]] + hypotheses = [sample.hypothesis for sample in samples] + score = sacrebleu.corpus_bleu( + hypotheses=hypotheses, references=references, **kwargs + ).score + return eval_types.BleuResult(score=score) diff --git a/ultravox/evaluation/wer.py b/ultravox/evaluation/wer.py deleted file mode 100644 index cdba9269..00000000 --- a/ultravox/evaluation/wer.py +++ /dev/null @@ -1,33 +0,0 @@ -# compute WER comparing a set of references with a set of hypotheses - -import jiwer - -from ultravox.evaluation import eval_types - - -def compute_wer(references, hypotheses): - transforms = jiwer.Compose( - [ - jiwer.ExpandCommonEnglishContractions(), - jiwer.RemoveEmptyStrings(), - jiwer.ToLowerCase(), - jiwer.RemoveMultipleSpaces(), - jiwer.Strip(), - jiwer.RemovePunctuation(), - jiwer.ReduceToListOfListOfWords(), - ] - ) - - wer = jiwer.wer( - references, - hypotheses, - truth_transform=transforms, - hypothesis_transform=transforms, - ) - - return wer - - -def evaluate_answer_asr(sample: eval_types.Sample): - wer_rate = compute_wer([sample.expected_answer], [sample.generated_answer]) - return eval_types.WerResult(score=min(1.0, wer_rate)) diff --git a/ultravox/inference/infer.py b/ultravox/inference/infer.py index f117c6ee..28e50ead 100644 --- a/ultravox/inference/infer.py +++ b/ultravox/inference/infer.py @@ -38,7 +38,6 @@ def __init__( ) self.data_collator = datasets.DataCollatorForSeq2SeqWithAudio( tokenizer=self.tokenizer, - include_alt_fields=False, ) assert self.tokenizer.padding_side == "left" @@ -61,18 +60,18 @@ def _get_sample_with_past( def _build_past_messages( self, query_messages: List[Dict[str, str]], - audio_token_len: int, + num_audio_tokens: int, response_content: str, ) -> List[Dict[str, str]]: messages = copy.copy(query_messages) - if audio_token_len > 0: + if num_audio_tokens > 0: user_content = messages[-1]["content"] if user_content.count("<|audio|>") != 1: raise ValueError( f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}" ) messages[-1]["content"] = user_content.replace( - "<|audio|>", self.tokenizer.eos_token * audio_token_len + "<|audio|>", self.tokenizer.eos_token * num_audio_tokens ) messages.append({"role": "assistant", "content": response_content}) return messages @@ -93,9 +92,11 @@ def infer( output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True) output_len = len(output_tokens) if self.conversation_mode: - audio_token_len = inputs.get("audio_token_len", [0])[0] + num_audio_tokens = ( + output.past_key_values[0][0].shape[2] - input_len - output_len + ) past_messages = self._build_past_messages( - extended_sample.messages, audio_token_len, output_text + extended_sample.messages, num_audio_tokens, output_text ) self.update_conversation(past_messages, output.past_key_values) return base.VoiceOutput(output_text, input_len, output_len) @@ -114,6 +115,12 @@ def infer_batch( input[key] = val.squeeze(0) tensors = self.data_collator(inputs) + # Move non-None tensors to the same device as the model + tensors = { + k: v.to(self.model.device) if v is not None else v + for k, v in tensors.items() + } + input_len = tensors["input_ids"].shape[1] output_batch = self._generate( tensors, max_tokens, temperature, return_dict_in_generate=False @@ -135,7 +142,7 @@ def infer_stream( ) -> base.InferenceGenerator: extended_sample = self._get_sample_with_past(sample) inputs = self._dataproc(extended_sample) - input_tokens = inputs["input_ids"].shape[1] + input_len = inputs["input_ids"].shape[1] 
         streamer = transformers.TextIteratorStreamer(
             self.tokenizer, skip_prompt=True, skip_special_tokens=True
         )
@@ -152,21 +159,24 @@ def thunk(f: futures.Future):
         thread = threading.Thread(target=thunk, args=(future,))
         thread.start()
         output_text = ""
-        output_token_len = 0
         for chunk in streamer:
             if chunk:
                 output_text += chunk
-                output_token_len += 1
                 yield base.InferenceChunk(chunk)
         thread.join()
+        output_len = len(
+            self.tokenizer(output_text, add_special_tokens=False).input_ids
+        )
         output = future.result()
         if self.conversation_mode:
-            audio_token_len = inputs.get("audio_token_len", [0])[0]
+            num_audio_tokens = (
+                output.past_key_values[0][0].shape[2] - input_len - output_len
+            )
             past_messages = self._build_past_messages(
-                extended_sample.messages, audio_token_len, output_text
+                extended_sample.messages, num_audio_tokens, output_text
             )
             self.update_conversation(past_messages, output.past_key_values)
-        yield base.InferenceStats(input_tokens, output_token_len)
+        yield base.InferenceStats(input_len, output_len)

     def _dataproc(self, sample: datasets.VoiceSample):
         text_input = self.tokenizer.apply_chat_template(
@@ -218,8 +228,10 @@ def _generate(
         do_sample = temperature is not None

         terminators = [self.tokenizer.eos_token_id]
-        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
-            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
+        if "<|end_header_id|>" in self.tokenizer.added_tokens_encoder:
+            terminators.append(
+                self.tokenizer.convert_tokens_to_ids("<|end_header_id|>")
+            )

         return self.model.generate(
             **inputs,
diff --git a/ultravox/inference/ultravox_infer.py b/ultravox/inference/ultravox_infer.py
index 6765ece1..208bacdd 100644
--- a/ultravox/inference/ultravox_infer.py
+++ b/ultravox/inference/ultravox_infer.py
@@ -3,10 +3,10 @@
 import transformers

 from ultravox.inference import infer
-from ultravox.inference import utils
 from ultravox.model import ultravox_model
 from ultravox.model import ultravox_processing
 from ultravox.model import wandb_utils
+from ultravox.utils import device_helpers


 class UltravoxInference(infer.LocalInference):
@@ -32,8 +32,12 @@ def __init__(
             data_type: data type to use for the model
             conversation_mode: if true, keep track of past messages in a conversation
         """
-        device = device or utils.default_device()
-        dtype = utils.get_dtype(data_type) if data_type else utils.default_dtype()
+        device = device or device_helpers.default_device()
+        dtype = (
+            device_helpers.get_dtype(data_type)
+            if data_type
+            else device_helpers.default_dtype()
+        )
         if wandb_utils.is_wandb_url(model_path):
             model_path = wandb_utils.download_model_from_wandb(model_path)
         model = ultravox_model.UltravoxModel.from_pretrained(
diff --git a/ultravox/inference/utils.py b/ultravox/inference/utils.py
deleted file mode 100644
index 5f04ff8c..00000000
--- a/ultravox/inference/utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import torch
-
-
-def default_device():
-    return (
-        "cuda"
-        if torch.cuda.is_available()
-        else "mps" if torch.backends.mps.is_available() else "cpu"
-    )
-
-
-def default_dtype():
-    # macOS Sonoma 14 enabled bfloat16 on MPS.
- return ( - torch.bfloat16 - if torch.cuda.is_available() or torch.backends.mps.is_available() - else torch.float16 - ) - - -def get_dtype(data_type: str): - return ( - torch.bfloat16 - if data_type == "bfloat16" - else torch.float16 if data_type == "float16" else torch.float32 - ) diff --git a/ultravox/tools/data_tool.py b/ultravox/tools/data_tool.py index bc7a565a..262257ba 100644 --- a/ultravox/tools/data_tool.py +++ b/ultravox/tools/data_tool.py @@ -29,11 +29,15 @@ def main(args: argparse.Namespace): num_prompts=args.num_prompts, shuffle=args.shuffle, use_mds=args.mds, - split=args.data_split, ) if args.seed is not None: data_args.shuffle_seed = args.seed - data_sets = [datasets.create_dataset(ds, data_args) for ds in args.data_sets] + data_sets = [ + datasets.create_dataset( + data_args, datasets.DatasetConfig(type=ds_type, splits=[args.data_split]) + ) + for ds_type in args.data_sets + ] out_set = datasets.Range(datasets.InterleaveDataset(data_sets), args.num_samples) for i, sample in enumerate(out_set): print(f"--- Sample {i} ---") diff --git a/ultravox/tools/eval_tool.py b/ultravox/tools/eval_tool.py index 3334983d..82a98c6f 100644 --- a/ultravox/tools/eval_tool.py +++ b/ultravox/tools/eval_tool.py @@ -1,35 +1,131 @@ -import argparse +#!/usr/bin/env python import dataclasses -from typing import IO +import datetime +import logging +import os +import sys +from pathlib import Path +from typing import List, Optional import simple_parsing +import torch +import torch.distributed as dist +import wandb +from ultravox.data import datasets from ultravox.evaluation import eval -from ultravox.evaluation import eval_types +from ultravox.inference import ultravox_infer +from ultravox.training import ddp_utils +from ultravox.utils import device_helpers +from ultravox.utils import string_helpers + +logging.basicConfig(level=logging.INFO) @dataclasses.dataclass class EvalArgs: - # Path to the audio file - file: IO = simple_parsing.field(type=argparse.FileType("rb"), alias="-f") - # Metric to use for evaluation (e.g., "asr" or "boolq") - metric: str = simple_parsing.field(default="asr", alias="-m") - # Verbose output - verbose: bool = simple_parsing.field(default=False, alias="-v") - - -def main(args: EvalArgs): - scores = [] - for i, line in enumerate(args.file.readlines()): - sample = eval_types.Sample.from_json(line) - result = eval.evaluate_answer(sample, metric=args.metric) - assert result.score is not None, "Rating failed." 
- scores.append(result.score) - average = sum(scores) / len(scores) - print(f"{i}: score={result.score:.2f} average={average:.2f}") - if args.verbose and isinstance(result, eval_types.InstructResult): - print(f" reason={result.reason}") + model_path: str = simple_parsing.field() + """Model ID to use for the model""" + + eval_dataset_configs: List[datasets.DatasetConfig] = dataclasses.field( + default_factory=list + ) + """List of evaluation dataset configurations""" + + eval_dataset_args: datasets.VoiceDatasetArgs = dataclasses.field( + default_factory=datasets.VoiceDatasetArgs + ) + """Global arguments for the evaluation dataset""" + + device: str = device_helpers.default_device() + """Device to use for training (e.g., 'cuda', 'cpu', 'mps')""" + + data_type: str = device_helpers.default_dtype_str() + """Data type to use for training (e.g., 'bfloat16', 'float16', 'float32')""" + + exp_name: Optional[str] = simple_parsing.field(default=None) + """The experiment name""" + + output_dir: Optional[Path] = simple_parsing.field(default=None) + """Output directory""" + + report_logs_to: List[str] = simple_parsing.field(default_factory=list) + """Whether to report results to wandb""" + + def __post_init__(self): + self.eval_dataset_configs = [ + datasets.DatasetConfig(**config) for config in self.eval_dataset_configs + ] + + if self.data_type not in ["bfloat16", "float16", "float32", None]: + raise ValueError( + f"Invalid data type: {self.data_type}. Please specify one of the following: bfloat16, float16, float32." + ) + if self.device is None: + self.device = device_helpers.default_device() + + if self.output_dir is None: + self.output_dir = Path("runs") / self.exp_name + self.output_dir = Path(self.output_dir) + + if self.exp_name is None: + self.exp_name = datetime.datetime.now().strftime("exp--%Y-%m-%d--%H-%M-%S") + + +def main(): + args = simple_parsing.parse( + config_class=EvalArgs, + add_config_path_arg=True, + args=[string_helpers.fix_hyphens(arg) for arg in sys.argv[1:]], + ) + + world_size = device_helpers.get_world_size() + local_rank = device_helpers.get_local_rank() + device = torch.device(args.device, index=local_rank) + if world_size > 1: + dist.init_process_group(backend="gloo") + + if local_rank == 0: + args.output_dir.mkdir(parents=True, exist_ok=True) + if "wandb" in args.report_logs_to: + wandb.init( + project=os.getenv("WANDB_PROJECT", "ultravox"), + config=dataclasses.asdict(args), + name=args.exp_name, + dir="runs", + ) + + with ddp_utils.run_on_master_first(local_rank == 0): + inference = ultravox_infer.UltravoxInference( + args.model_path, + device=device, + data_type=device_helpers.get_dtype(args.data_type), + ) + + metrics, output_files = eval.run_infer( + inference, + args.eval_dataset_args, + args.eval_dataset_configs, + world_size, + local_rank, + ) + + if local_rank == 0: + if "wandb" in args.report_logs_to: + wandb.log(metrics) + for output_file in output_files: + wandb.save(output_file) + wandb.finish() + else: + logging.info(f"Evaluation Scores:\n") + for metric, score in metrics.items(): + logging.info(f" {metric}: {score}") + logging.info("Output Files:\n") + for output_file in output_files: + logging.info(f" {output_file}") + if world_size > 1: + dist.destroy_process_group() if __name__ == "__main__": - main(simple_parsing.parse(EvalArgs)) + main() diff --git a/ultravox/tools/infer_tool.py b/ultravox/tools/infer_tool.py deleted file mode 100644 index 36ea4a0d..00000000 --- a/ultravox/tools/infer_tool.py +++ /dev/null @@ -1,255 +0,0 @@ -#!/usr/bin/env python 
-import argparse -import dataclasses -import json -import os -import time -from typing import IO, List, Optional - -import numpy as np -import simple_parsing -from torch.utils import data as data_utils - -from ultravox.data import datasets -from ultravox.evaluation import eval -from ultravox.evaluation import eval_types -from ultravox.inference import base -from ultravox.tools import infer_api - -# There are two default modes for this tool, agent mode and ASR mode. -# In agent mode, the answer is a response to the input content and cannot be -# directly compared to the expected answer. In ASR mode, the answer is a -# transcription of the audio content and the tool can perfom a WER calculation. -# Remember to set the --asr flag when using an ASR input. -DEFAULT_PROMPT = "Listen to <|audio|> and respond to it" -DEFAULT_ASR_PROMPT = "Transcribe\n<|audio|>" - - -@dataclasses.dataclass -class InferArgs: - # Model ID to use for the model - model: str = simple_parsing.field(default="fixie-ai/ultravox-v0_2", alias="-m") - # Path to the audio file - audio_file: Optional[IO] = simple_parsing.field( - default=None, type=argparse.FileType("rb"), alias="-f" - ) - # Prompt to use for inference - prompt: Optional[str] = None - # Inference the model using only the text input or transcript, without audio - text_only: bool = False - # Use ASR for the prompt and compute WER - asr: bool = False - # URL to use for inference - url: Optional[str] = simple_parsing.field(default=None, alias="-u") - # Audio processor ID to use - audio_processor: Optional[str] = None - # Tokenizer ID to use - tokenizer: Optional[str] = None - # Data sets to use for inference - data_sets: Optional[List[str]] = simple_parsing.field(default=None, alias="-d") - # Which dataset split to use - data_split: datasets.DatasetSplit = simple_parsing.field( - default=datasets.DatasetSplit.VALIDATION, alias="-s" - ) - # Directory for existing data - data_dir: Optional[str] = None - # Use dataset context - context: bool = False - # Load datasets using MDS - mds: bool = False - # Number of dataset samples to process - num_samples: int = simple_parsing.field(default=1, alias="-n") - # Shuffle the dataset - shuffle: bool = False - # Seed for shuffling - seed: Optional[int] = None - # Device to use for inference - device: Optional[str] = simple_parsing.field(default=None, alias="-D") - # Data type to use for the model - data_type: Optional[str] = None - # Temperature for sampling - temperature: Optional[float] = simple_parsing.field(default=None, alias="-t") - # Maximum tokens to generate - max_tokens: Optional[int] = simple_parsing.field(default=None, alias="-T") - # Evaluate the generated answer - eval: bool = simple_parsing.field(default=False, alias="-e") - # Verbose output - verbose: bool = simple_parsing.field(default=False, alias="-v") - # JSON output - json: bool = simple_parsing.field(default=False) - # Batch size - batch_size: Optional[int] = simple_parsing.field(default=1, alias="-b") - - def __post_init__(self): - if self.prompt and self.prompt.startswith("@"): - with open(self.prompt[1:], "r") as f: - self.prompt = f.read() - - -def run_tui( - index: int, - inference: base.VoiceInference, - sample: datasets.VoiceSample, - args: InferArgs, - expected_response: Optional[str] = None, - scores: Optional[List[float]] = None, -): - if index >= 0: - print(f"--- Sample {index} ---") - messages = sample.messages - question_message = messages[-2] if len(messages) > 1 else messages[-1] - transcript = f' ["{sample.audio_transcript}"]' if 
sample.audio_transcript else "" - print(f"Q: {question_message['content']}{transcript}") - print(f"A: ", end="") - start_time = time.time() - first_token_time = None - text = "" - stats = None - - # Run streaming inference and print the output as it arrives. - stream = inference.infer_stream( - sample, - max_tokens=args.max_tokens, - temperature=args.temperature, - ) - for msg in stream: - if isinstance(msg, base.InferenceChunk): - if first_token_time is None: - first_token_time = time.time() - text += msg.text - print(msg.text, end="", flush=True) - elif isinstance(msg, base.InferenceStats): - stats = msg - if first_token_time is None or stats is None: - raise ValueError("No tokens received") - - # If we're in verbose mode, print some stats about the inference. - if args.verbose: - ttft = first_token_time - start_time - total_time = time.time() - start_time - tokens = stats.output_tokens - tps = tokens / (total_time - ttft) - print( - f" [ttft: {ttft:.2f} s, tok: {tokens}, tps: {tps:.2f}, tot: {total_time:.2f} s]", - end="", - ) - - # Print the expected response (and eval if desired). - print() - if expected_response is not None: - eval_str = "" - if scores is not None: - assert args.data_sets - ds_name = args.data_sets[0] - eval_sample = eval_types.Sample( - sample.audio_transcript or question_message["content"], - expected_answer=expected_response, - generated_answer=text, - ) - eval_metric = ( - "asr" if args.asr else "boolq" if ds_name == "boolq" else "instruct" - ) - result = eval.evaluate_answer(eval_sample, eval_metric) - if result.score is not None: - scores.append(result.score) - eval_name = "score" - reason_str = "" - mean = np.mean(scores) - if isinstance(result, eval_types.WerResult): - eval_name = "wer" - elif isinstance(result, eval_types.InstructResult) and args.verbose: - reason_str = f" ({result.reason})" - eval_str = ( - f" [{eval_name}: {result.score:.2f}{reason_str}, avg: {mean:.2f}]" - ) - else: - eval_str = " [eval failed]" - print(f"X: {expected_response}{eval_str}") - - -def oneshot_infer(inference: base.VoiceInference, args: InferArgs): - prompt = args.prompt or (DEFAULT_ASR_PROMPT if args.asr else DEFAULT_PROMPT) - if args.audio_file is not None: - sample = datasets.VoiceSample.from_prompt_and_buf( - prompt, args.audio_file.read() - ) - else: - sample = datasets.VoiceSample.from_prompt(prompt) - run_tui(-1, inference, sample, args) - - -def dataset_infer(inference: base.VoiceInference, args: InferArgs): - assert args.data_sets, "At least one data set must be provided" - ds_args = datasets.VoiceDatasetArgs( - data_dir=args.data_dir, - prompt=args.prompt, - include_audio=not args.text_only, - include_context=args.context, - shuffle=args.shuffle, - use_mds=args.mds, - split=args.data_split, - ) - if args.seed is not None: - ds_args.shuffle_seed = args.seed - ds = datasets.create_dataset(args.data_sets[0], ds_args) - - if args.json: - dl = data_utils.DataLoader( - datasets.Range(ds, args.num_samples), - batch_size=args.batch_size, - collate_fn=lambda x: x, - ) - sample_index = 0 - for input_batch in dl: - expected_answers = [ - sample.messages.pop()["content"] for sample in input_batch - ] - output_batch = inference.infer_batch( - input_batch, - max_tokens=args.max_tokens, - temperature=args.temperature, - ) - for sample, generated, expected in zip( - input_batch, output_batch, expected_answers - ): - output = { - "index": sample_index, - "question": sample.audio_transcript, - "expected_answer": expected, - "generated_answer": generated.text, - } - sample_index += 1 
- print(json.dumps(output)) - else: - scores: List[float] = [] - for i, sample in enumerate(datasets.Range(ds, args.num_samples)): - # Store the answer for JSON output. - expected_answer = sample.messages[-1]["content"] - # Drop any assistant response from the sample. - sample.messages = sample.messages[:-1] - run_tui(i, inference, sample, args, expected_answer, scores) - - -def main(args: InferArgs): - if args.url is not None: - api_key = os.environ.get("ULTRAVOX_API_KEY") - inference = infer_api.create_inference(args.url, args.model, api_key) - else: - # Only load our local inference module if we're not using the API. - from ultravox.inference import ultravox_infer - - inference = ultravox_infer.UltravoxInference( - args.model, - tokenizer_id=args.tokenizer, - audio_processor_id=args.audio_processor, - device=args.device, - data_type=args.data_type, - ) - if args.data_sets is None: - oneshot_infer(inference, args) - else: - dataset_infer(inference, args) - - -if __name__ == "__main__": - main(simple_parsing.parse(InferArgs)) diff --git a/ultravox/tools/push_to_hub.py b/ultravox/tools/push_to_hub.py index 01945321..b9ebe7f6 100644 --- a/ultravox/tools/push_to_hub.py +++ b/ultravox/tools/push_to_hub.py @@ -49,7 +49,10 @@ def main(args: UploadToHubArgs): loaded_pipe = transformers.pipeline( model=args.hf_upload_model, trust_remote_code=True ) - ds = datasets.BoolQDataset(datasets.VoiceDatasetArgs()) + ds = datasets.BoolQDataset( + datasets.VoiceDatasetArgs(), + datasets.DatasetConfig(type="boolq", num_samples=10), + ) sample = next(iter(ds)) generated = loaded_pipe( {"audio": sample.audio, "turns": sample.messages[:-1]}, max_new_tokens=10 diff --git a/ultravox/tools/score_tool.py b/ultravox/tools/score_tool.py new file mode 100644 index 00000000..5f227ae6 --- /dev/null +++ b/ultravox/tools/score_tool.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +import dataclasses +import json +import sys +from pathlib import Path +from typing import Dict + +import simple_parsing + +from ultravox.evaluation import eval_helpers +from ultravox.evaluation import eval_types +from ultravox.utils import string_helpers + + +@dataclasses.dataclass +class ScoreArgs: + # Path to JSON file containing samples + input: Path = simple_parsing.field(alias="-i") + # Metric to use for evaluation + metric: str = simple_parsing.field(alias="-m") + # Additional arguments for the metric configuration + args: Dict[str, str] = simple_parsing.field(alias="-a", default_factory=dict) + + +# Used to score a file of ultravox output samples containing reference and hypothesis text, using the specified metric and any additional arguments. 
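+# Example invocation (illustrative; the input path is hypothetical and should point
+# at a results JSON produced by eval.run_infer):
+#   python -m ultravox.tools.score_tool -i output/covost2-zh_en-test-2k.json -m bleu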
+def main(): + args = simple_parsing.parse( + config_class=ScoreArgs, + args=[string_helpers.fix_hyphens(arg) for arg in sys.argv[1:]], + ) + + # Load samples from JSON file + with open(args.input, "r") as f: + samples_data = json.load(f) + + samples = [eval_types.Sample(**sample) for sample in samples_data] + + # Create EvalConfig with args + metric_config = eval_types.EvalConfig(metric=args.metric, args=args.args) + + # Evaluate samples + result = eval_helpers.evaluate_answers(samples, metric_config) + + # Print result + print(f"input: {args.input}") + print(f"args: {args.args}") + print(f"{args.metric}: {result.score:.2f}") + + +if __name__ == "__main__": + main() diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index baafc933..6568e418 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -8,104 +8,183 @@ from typing import Any, Dict, List, Optional import simple_parsing -import torch -from ultravox.data import dataset_config from ultravox.data import datasets from ultravox.model import ultravox_config +from ultravox.utils import device_helpers, string_helpers @dataclasses.dataclass class TrainConfig: - data_sets: List[str] - val_sets: List[str] - # language model to use + """Configuration class for training settings.""" + text_model: str - # audio encoder model to use + """Language model to use; could be a huggingface model id, wandb path, or local path.""" + audio_model: str + """Audio encoder model to use; could be a huggingface model id, wandb path, or local path.""" + + adapter_type: ultravox_config.AdapterType = ultravox_config.AdapterType.STACKING + """Type of adapter to use for fine-tuning. Defaults to STACKING.""" + + adapter_config: Optional[Dict[str, Any]] = None + """Optional configuration dictionary for the adapter. If None, default settings will be used.""" + + train_dataset_configs: List[datasets.DatasetConfig] = dataclasses.field( + default_factory=list + ) + """ + List of training dataset configurations. + Initially parsed as dictionaries (due to limitations of simple_parsing), then converted to DatasetConfig objects in __post_init__. + """ + + val_dataset_configs: List[datasets.DatasetConfig] = dataclasses.field( + default_factory=list + ) + """ + List of validation dataset configurations. + Initially parsed as dictionaries (due to limitations of simple_parsing), then converted to DatasetConfig objects in __post_init__. + """ + + eval_dataset_configs: List[datasets.DatasetConfig] = dataclasses.field( + default_factory=list + ) + """ + List of evaluation dataset configurations. + Initially parsed as dictionaries (due to limitations of simple_parsing), then converted to DatasetConfig objects in __post_init__. + """ + + train_dataset_args: datasets.VoiceDatasetArgs = dataclasses.field( + default_factory=datasets.VoiceDatasetArgs + ) + """Global arguments for the training dataset.""" + + val_dataset_args: datasets.VoiceDatasetArgs = dataclasses.field( + default_factory=datasets.VoiceDatasetArgs + ) + """Global arguments for the validation dataset.""" - # The data_dicts field complements data_sets, allowing for the inclusion of - # new datasets in the config. - # - # Due to simple_parsing's lack of support for containers of dataclass types, - # we first parse the data_dicts as a list of dictionaries. After parsing, - # we convert these dictionaries to DataDictConfig objects using Pydantic - # to enforce type constraints and validation, in the __post_init__ method. 
- data_dicts: Optional[List[Dict[str, Any]]] = None + eval_dataset_args: datasets.VoiceDatasetArgs = dataclasses.field( + default_factory=datasets.VoiceDatasetArgs + ) + """Global Arguments for the evaluation dataset.""" do_train: bool = True + """Whether to perform training.""" + do_eval: bool = True + """Whether to perform evaluation.""" - # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED - data_dir: Optional[str] = None - mds: bool = False - num_samples: Optional[int] = None - val_num_samples: int = 100 - eval_num_samples: int = 100 - eval_max_new_tokens: Optional[int] = None - eval_num_procs: int = 8 - eval_text_only: bool = False - num_prompts: int = 1 - # number of data loader workers - num_workers: int = 8 if torch.cuda.is_available() else 1 + """ + The stop strategy for InterleaveDataset when combining multiple datasets for training, + choose from last_exhausted (default), first_exhausted, or never_stop (rely on max_steps or num_epochs to stop, and should be used as the default. + """ + + num_workers: int = device_helpers.get_world_size() + """Number of data loader workers.""" + train_on_inputs: bool = False - shuffle_data: bool = False - # Maximum audio duration in seconds. Samples with longer audio will be skipped. - # This is usually due to GPU memory constraints and also dependends on the dataset. - max_audio_duration_secs: Optional[float] = None + """Whether to train on inputs.""" verbose: bool = False + """Whether to enable verbose output.""" + + device: str = device_helpers.default_device() + """Device to use for training (e.g., 'cuda', 'cpu', 'mps').""" + + data_type: str = device_helpers.default_dtype_str() + """Data type to use for training (e.g., 'bfloat16', 'float16', 'float32').""" - device: str = "cuda" - data_type: str = "bfloat16" - # Path to load the model from. Can be local path, HF hub model_id, or W&B artifact model_load_dir: Optional[str] = None + """ + Path to load pretrained ultravox model from. Can be local path, HF hub model_id, or W&B artifact. 
+ """ + text_model_lora_config: Optional[ultravox_config.LoraConfigSimplified] = None + """LoRA configuration for the text model.""" + audio_model_lora_config: Optional[ultravox_config.LoraConfigSimplified] = None + """LoRA configuration for the audio model.""" + disable_layerdrop: bool = False + """Whether to disable layerdrop.""" - # The experiment name exp_name: Optional[str] = None + """The experiment name.""" + output_dir: Optional[Path] = None + """Directory to save output files.""" + logs_dir: Optional[Path] = None + """Directory to save log files.""" + optimizer: str = "adamw_torch" + """Optimizer to use for training.""" + num_epochs: int = 1 + """Number of training epochs, only used when max_steps is 0.""" + max_steps: int = 0 + """Maximum number of training steps, if 0, use num_epochs.""" + val_steps: Optional[int] = None + """Number of steps between validations.""" + save_steps: float = 0 + """Number of steps between model saves.""" + logging_steps: int = 1 + """Number of steps between logging.""" + grad_accum_steps: int = 1 + """Number of gradient accumulation steps.""" + val_accum_steps: int = 1 - batch_size: int = 2 + """Number of validation accumulation steps.""" + lr: float = 1e-5 + """Learning rate.""" + lr_scheduler: str = "cosine" + """Learning rate scheduler type.""" + lr_warmup_steps: int = 0 + """Number of learning rate warmup steps.""" + weight_decay: float = 0.0 + """Weight decay for optimizer.""" + seed: int = 42 - shuffle_seed: int = 42 - # Experiment logging destinations: tensorboard, wandb, neptune, mlflow, etc + """Random seed for reproducibility.""" + report_logs_to: List[str] = simple_parsing.list_field("tensorboard") - # A list of tags for filtering runs. Only used for wandb. + """Experiment logging destinations: tensorboard, wandb, etc.""" + run_tags: List[str] = simple_parsing.list_field() + """A list of tags for filtering runs. Only used for wandb.""" - # loss function to use - loss_config: Optional[ultravox_config.LossConfig] = None + loss_config: ultravox_config.LossConfig = dataclasses.field( + default_factory=ultravox_config.LossConfig + ) + """Configuration for the loss function.""" def __post_init__(self): - if self.data_dicts: - self.data_dicts = [ - dataset_config.DataDictConfig(**data_dict) - for data_dict in self.data_dicts - ] - # For now, self.data_dicts is a hack to allow for the inclusion of new datasets using the - # GenericVoiceDataset class, without changing how existing datasets are specified in - # self.data_sets. In the future, all datasets will be updated to use the DataDictConfig class. 
- self.data_sets.extend(self.data_dicts) + self.train_dataset_configs = [ + datasets.DatasetConfig(**data_dict) + for data_dict in self.train_dataset_configs + ] + self.val_dataset_configs = [ + datasets.DatasetConfig(**data_dict) + for data_dict in self.val_dataset_configs + ] + self.eval_dataset_configs = [ + datasets.DatasetConfig(**data_dict) + for data_dict in self.eval_dataset_configs + ] assert self.data_type in ["bfloat16", "float16", "float32"] - if self.device == "cuda" and not torch.cuda.is_available(): - self.device = "mps" if torch.backends.mps.is_available() else "cpu" + assert self.device in ["cuda", "cpu", "mps"] if self.device != "cuda": if self.data_type == "bfloat16": self.data_type = "float32" @@ -117,12 +196,16 @@ def __post_init__(self): if self.exp_name is None: self.exp_name = datetime.datetime.now().strftime("exp--%Y-%m-%d--%H-%M-%S") + sanitized_exp_name = string_helpers.sanitize_name(self.exp_name) + if sanitized_exp_name != self.exp_name: + logging.warning( + f"Experiment name {self.exp_name} has been sanitized to {sanitized_exp_name} for use as a valid file/directory/module name." + ) + self.exp_name = sanitized_exp_name + if self.output_dir is None: self.output_dir = Path("runs") / self.exp_name - # HF Pipeline gets tripped up if the path has a "." in it - self.output_dir = self.output_dir.replace(".", "--") - if self.logs_dir is None: self.logs_dir = self.output_dir / "logs" @@ -137,11 +220,6 @@ def __post_init__(self): ) self.disable_layerdrop = True - -def fix_hyphens(arg: str): - return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) - - def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig: """ Parse the command line arguments and return a TrainConfig object. @@ -156,5 +234,5 @@ def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig config_class=TrainConfig, config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"), add_config_path_arg=True, - args=[fix_hyphens(arg) for arg in args], - ) + args=[string_helpers.fix_hyphens(arg) for arg in args], + ) \ No newline at end of file diff --git a/ultravox/training/configs/asr_llama.yaml b/ultravox/training/configs/asr_llama.yaml deleted file mode 100644 index 4417d113..00000000 --- a/ultravox/training/configs/asr_llama.yaml +++ /dev/null @@ -1,3 +0,0 @@ -exp_name: "llama3_wav2vec2_asr" - -text_model: "meta-llama/Meta-Llama-3-8B-Instruct" diff --git a/ultravox/training/configs/asr_tinyllama.yaml b/ultravox/training/configs/asr_tinyllama.yaml deleted file mode 100644 index 3c3a2a85..00000000 --- a/ultravox/training/configs/asr_tinyllama.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# small scale model training: mostly for debugging purposes - -text_model: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -exp_name: "tinyllama_asr" diff --git a/ultravox/training/configs/asr_tinyllama_100s.yaml b/ultravox/training/configs/asr_tinyllama_100s.yaml deleted file mode 100644 index 0cb38740..00000000 --- a/ultravox/training/configs/asr_tinyllama_100s.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# test config for fast experimentation, only 100 steps - -text_model: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -exp_name: "tinyllama_asr_100s" - -max_steps: 100 -lr_warmup_steps: 10 diff --git a/ultravox/training/configs/llama3_whisper.yaml b/ultravox/training/configs/llama3_whisper.yaml deleted file mode 100644 index f8e77bb3..00000000 --- a/ultravox/training/configs/llama3_whisper.yaml +++ /dev/null @@ -1,6 +0,0 @@ -# SLM with ultravox & llama3 -exp_name: 
"llama3_whisper_s" - -# Make sure to accept the license agreement on huggingface hub -text_model: "meta-llama/Meta-Llama-3-8B-Instruct" -audio_model: "openai/whisper-small" \ No newline at end of file diff --git a/ultravox/training/configs/llama3_whisper_kd.yaml b/ultravox/training/configs/llama3_whisper_kd.yaml deleted file mode 100644 index d951f02d..00000000 --- a/ultravox/training/configs/llama3_whisper_kd.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# SLM with ultravox & llama3, trained wtih knowledge distillation. -exp_name: "llama3_whisper_s" - -# Make sure to accept the license agreement on huggingface hub -text_model: "meta-llama/Meta-Llama-3-8B-Instruct" -audio_model: "openai/whisper-small" - - -loss_config: - # Choose from ["KL_Divergence", "CrossEntropy"], default is "KL_Divergence" - loss_function: "KL_Divergence" - -# Temporarily remove heysquad_human from val_sets as it causes the training to fail. -val_sets: ["anyinstruct", "soda", "peoplespeech"] - -batch_size: 4 -max_steps: 1000 - -data_sets: [] -data_dicts: - - path: "fixie-ai/librispeech_asr" - name: "clean" - splits: - - "train.100" - - "train.360" - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - weight: 2 - num_samples: 100_000 - - path: "fixie-ai/librispeech_asr" - name: "other" - splits: - - "train.500" - user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" - assistant_template: "{{ continuation }}" - transcript_template: "{{ text }}" - weight: 1 - num_samples: 100_000 diff --git a/ultravox/training/configs/llama_whisper.yaml b/ultravox/training/configs/llama_whisper.yaml new file mode 100644 index 00000000..92b7c541 --- /dev/null +++ b/ultravox/training/configs/llama_whisper.yaml @@ -0,0 +1,102 @@ +# llama3.1-8b + whisper-medium, for development + +exp_name: "llama3.1-8b-whisper" +text_model: "meta-llama/Meta-Llama-3.1-8B-Instruct" +audio_model: "openai/whisper-medium" + +loss_config: + loss_function: "KL_Divergence" + +max_steps: 20 # x8x24 = 2,764,800 + +train_dataset_args: + shuffle: True + batch_size: 4 +train_dataset_configs: +# continuation + - alias: "librispeech-clean" + path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + active_samples: 100 + num_samples: 132_553 + user_template: "<|audio|>" + assistant_template: "" + transcript_template: "{{ text }}" + weight: 1 + - alias: "librispeech-other" + path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + num_samples: 148_688 + user_template: "<|audio|>" + assistant_template: "" + transcript_template: "{{ text }}" + weight: 1 + + +val_dataset_args: + batch_size: 4 +val_dataset_configs: + - alias: "covost2-en_zh-val" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "validation" + active_samples: 20 + num_samples: 15531 + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + - alias: "covost2-zh_en-val" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "validation" + active_samples: 20 + num_samples: 4843 + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + +eval_dataset_args: + batch_size: 4 + max_tokens: 20 +eval_dataset_configs: + - alias: "covost2-en_zh-test" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 20 + num_samples: 15531 + eval_config: + metric: "bleu" + args: + tokenize: "zh" + - alias: "covost2-zh_en-test" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 20 + num_samples: 4898 + eval_config: + metric: "bleu" \ No newline at end of file diff --git a/ultravox/training/configs/meta_config.yaml b/ultravox/training/configs/meta_config.yaml index d3764d29..a4f0a7e8 100644 --- a/ultravox/training/configs/meta_config.yaml +++ b/ultravox/training/configs/meta_config.yaml @@ -1,19 +1,11 @@ text_model: "meta-llama/Meta-Llama-3-8B-Instruct" audio_model: "facebook/wav2vec2-base-960h" -data_sets: ["gigaspeech"] -val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"] stop_strategy: "LAST_EXHAUSTED" train_on_inputs: False -shuffle_data: True -max_audio_duration_secs: 16 -val_num_samples: 64 val_steps: 1000 -eval_num_samples: 2000 -eval_max_new_tokens: 32 -eval_num_procs: 16 optimizer: "adamw_torch" # options: adamw_torch, adamw_bnb_8bit lr_scheduler: "cosine" # options: linear, cosine, cosine_with_restarts, etc. @@ -25,7 +17,185 @@ max_steps: 10_000 save_steps: 0.25 logging_steps: 100 -batch_size: 4 -data_type: "bfloat16" - report_logs_to: ["tensorboard", "wandb"] + +# include speech translation (zh-en, en-zh) and transcription (peoples_speech) for validation, with and without audio +val_dataset_args: + batch_size: 40 +val_dataset_configs: + - alias: "covost2-en_zh-val" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "validation" + active_samples: 200 + num_samples: 15531 + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + - alias: "covost2-zh_en-val" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "validation" + active_samples: 200 + num_samples: 4843 + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + - alias: "peoples_speech-val" + path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + active_samples: 200 + num_samples: 1_501_271 + user_template: "Transcribe\n<|audio|>" + assistant_template: "{{ text_proc.format_asr_text(text) }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + - alias: "covost2-en_zh-text-val" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "validation" + active_samples: 200 + num_samples: 15531 + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n{{ sentence }}" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + audio_field: null + - alias: "covost2-zh_en-text-val" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "validation" + active_samples: 200 + num_samples: 4843 + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n{{ sentence }}" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + audio_field: null + - alias: "peoples_speech-text-val" + path: "fixie-ai/peoples_speech" + name: "clean" + splits: + - "train" # 1_501_271 samples + active_samples: 200 + num_samples: 1_501_271 + user_template: "Transcribe\n{{ text_proc.format_asr_text(text) }}" + assistant_template: "{{ text_proc.format_asr_text(text) }}" + transcript_template: "{{ text_proc.format_asr_text(text) }}" + audio_field: null + +eval_dataset_args: + batch_size: 40 + max_tokens: 512 +eval_dataset_configs: + - alias: "covost2-en_de-test-2k" + path: "fixie-ai/covost2" + name: "en_de" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "German" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-en_zh-test-2k" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + args: + tokenize: "zh" + - alias: "covost2-en_ar-test-2k" + path: "fixie-ai/covost2" + name: "en_ar" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Arabic" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-en_ca-test-2k" + path: "fixie-ai/covost2" + name: "en_ca" + splits: + - "test" + user_template: "Please translate the text to {{target}}. 
Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Catalan" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-es_en-test-2k" + path: "fixie-ai/covost2" + name: "es_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-ru_en-test-2k" + path: "fixie-ai/covost2" + name: "ru_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" + - alias: "covost2-zh_en-test-2k" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 2000 + num_samples: 15531 + eval_config: + metric: "bleu" \ No newline at end of file diff --git a/ultravox/training/configs/release_config.yaml b/ultravox/training/configs/release_config.yaml index 36b7a5f6..1bf46722 100644 --- a/ultravox/training/configs/release_config.yaml +++ b/ultravox/training/configs/release_config.yaml @@ -16,14 +16,18 @@ val_sets: ["anyinstruct", "soda", "peoplespeech"] batch_size: 24 max_steps: 14400 # x8x24 = 2,764,800 -data_sets: ["anyinstruct"] -data_dicts: +train_dataset_args: + shuffle: True + batch_size: 24 + +train_dataset_configs: # continuation - path: "fixie-ai/librispeech_asr" name: "clean" splits: - "train.100" # 28_539 samples - "train.360" # 104_014 samples + num_samples: 132_553 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text }}" @@ -32,6 +36,7 @@ data_dicts: name: "other" splits: - "train.500" # 148_688 samples + num_samples: 148_688 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text }}" @@ -40,6 +45,7 @@ data_dicts: name: "clean" splits: - "train" # 1_501_271 samples + num_samples: 1_501_271 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text_proc.format_asr_text(text) }}" @@ -48,6 +54,7 @@ data_dicts: name: "en" splits: - "train" # 1_101_170 samples + num_samples: 1_101_170 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ text_proc.format_asr_text(sentence) }}" @@ -56,6 +63,7 @@ data_dicts: name: "ar" splits: - "train" # 28_369 samples + num_samples: 28_369 user_template: "Continue the following text using less than 50 
words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -64,6 +72,7 @@ data_dicts: name: "de" splits: - "train" # 589_100 samples + num_samples: 589_100 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -72,6 +81,7 @@ data_dicts: name: "es" splits: - "train" # 336_846 samples + num_samples: 336_846 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -80,6 +90,7 @@ data_dicts: name: "fr" splits: - "train" # 558_054 samples + num_samples: 558_054 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -88,6 +99,7 @@ data_dicts: name: "it" splits: - "train" # 169_771 samples + num_samples: 169_771 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -96,6 +108,7 @@ data_dicts: name: "ja" splits: - "train" # 10_039 samples + num_samples: 10_039 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -104,6 +117,7 @@ data_dicts: name: "pt" splits: - "train" # 21_968 samples + num_samples: 21_968 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -112,6 +126,7 @@ data_dicts: name: "ru" splits: - "train" # 26_377 samples + num_samples: 26_377 user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" assistant_template: "{{ continuation }}" transcript_template: "{{ sentence }}" @@ -122,6 +137,7 @@ data_dicts: splits: - "train.100" # 28_539 samples - "train.360" # 104_014 samples + num_samples: 132_553 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text }}" transcript_template: "{{ text }}" @@ -130,6 +146,7 @@ data_dicts: name: "other" splits: - "train.500" # 148_688 samples + num_samples: 148_688 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text }}" transcript_template: "{{ text }}" @@ -138,6 +155,7 @@ data_dicts: name: "clean" splits: - "train" # 1_501_271 samples + num_samples: 1_501_271 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(text) }}" transcript_template: "{{ text_proc.format_asr_text(text) }}" @@ -146,6 +164,7 @@ data_dicts: name: "en" splits: - "train" # 1_101_170 samples + num_samples: 1_101_170 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ text_proc.format_asr_text(sentence) }}" @@ -154,6 +173,7 @@ data_dicts: name: "ar" splits: - "train" # 28_369 samples + num_samples: 28_369 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -162,6 +182,7 @@ data_dicts: name: "de" splits: - "train" # 589_100 samples + num_samples: 589_100 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -170,6 +191,7 @@ data_dicts: name: 
"es" splits: - "train" # 336_846 samples + num_samples: 336_846 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -178,6 +200,7 @@ data_dicts: name: "fr" splits: - "train" # 558_054 samples + num_samples: 558_054 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -186,6 +209,7 @@ data_dicts: name: "it" splits: - "train" # 169_771 samples + num_samples: 169_771 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -194,6 +218,7 @@ data_dicts: name: "ja" splits: - "train" # 10_039 samples + num_samples: 10_039 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -202,6 +227,7 @@ data_dicts: name: "pt" splits: - "train" # 21_968 samples + num_samples: 21_968 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" @@ -210,6 +236,7 @@ data_dicts: name: "ru" splits: - "train" # 26_377 samples + num_samples: 26_377 user_template: "{{ dataset._get_transcribe_prompt() }}" assistant_template: "{{ text_proc.format_asr_text(sentence) }}" transcript_template: "{{ sentence }}" diff --git a/ultravox/training/configs/stage2_lora.yaml b/ultravox/training/configs/stage2_lora.yaml deleted file mode 100644 index b580d255..00000000 --- a/ultravox/training/configs/stage2_lora.yaml +++ /dev/null @@ -1,23 +0,0 @@ -exp_name: stage2_lora__gs_ai_bq_soda - -text_model_lora_config: - r: 64 # no/little change in the range [16, 64] - target_modules: ['mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj', 'v_proj', 'o_proj', 'k_proj', 'q_proj'] - -data_sets: ["gigaspeech", "anyinstruct", "boolq_extended", "soda"] - -# disable_layer_drop: True -# audio_model_lora_config: -# r: 64 -# target_modules: ['k_proj', 'q_proj', 'v_proj', 'out_proj', 'intermediate_dense', 'output_dense'] - -num_prompts: 11 - -lr: 1.e-4 # need a lower LR for LLM fine-tuning -lr_scheduler: constant_with_warmup -lr_warmup_steps: 250 -max_steps: 5_000 -save_steps: 0 - -batch_size: 2 -grad_accum_steps: 2 diff --git a/ultravox/training/configs/tinyllama_whisper.yaml b/ultravox/training/configs/tinyllama_whisper.yaml new file mode 100644 index 00000000..6672d568 --- /dev/null +++ b/ultravox/training/configs/tinyllama_whisper.yaml @@ -0,0 +1,102 @@ +# tiny llama + whisper-small, for development + +exp_name: "tinyllama-whisper" +text_model: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +audio_model: "openai/whisper-small" + +loss_config: + loss_function: "KL_Divergence" + +max_steps: 20 # x8x24 = 2,764,800 + +train_dataset_args: + shuffle: True + batch_size: 4 +train_dataset_configs: +# continuation + - alias: "librispeech-clean" + path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" # 28_539 samples + - "train.360" # 104_014 samples + active_samples: 100 + num_samples: 132_553 + user_template: "<|audio|>" + assistant_template: "" + transcript_template: "{{ text }}" + weight: 1 + - alias: "librispeech-other" + path: "fixie-ai/librispeech_asr" + name: "other" + splits: + - "train.500" # 148_688 samples + num_samples: 148_688 + user_template: "<|audio|>" + assistant_template: "" + transcript_template: "{{ text }}" + weight: 
1 + + +val_dataset_args: + batch_size: 4 +val_dataset_configs: + - alias: "covost2-en_zh-val" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "validation" + active_samples: 20 + num_samples: 15531 + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + - alias: "covost2-zh_en-val" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "validation" + active_samples: 20 + num_samples: 4843 + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + +eval_dataset_args: + batch_size: 4 + max_tokens: 20 +eval_dataset_configs: + - alias: "covost2-en_zh-test" + path: "fixie-ai/covost2" + name: "en_zh-CN" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "Chinese" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 20 + num_samples: 15531 + eval_config: + metric: "bleu" + args: + tokenize: "zh" + - alias: "covost2-zh_en-test" + path: "fixie-ai/covost2" + name: "zh-CN_en" + splits: + - "test" + user_template: "Please translate the text to {{target}}. Your response should only include the {{target}} translation, without any additional words:\n\n<|audio|>" + user_template_args: + target: "English" + assistant_template: "{{ translation }}" + transcript_template: "{{ sentence }}" + active_samples: 20 + num_samples: 4898 + eval_config: + metric: "bleu" \ No newline at end of file diff --git a/ultravox/training/ddp_utils.py b/ultravox/training/ddp_utils.py index ba0382a5..0f5a2b93 100644 --- a/ultravox/training/ddp_utils.py +++ b/ultravox/training/ddp_utils.py @@ -34,7 +34,11 @@ def all_gather_list(data: List[T]) -> List[T]: return data world_size = torch.distributed.get_world_size() data_list = [None] * world_size - torch.distributed.all_gather_object(data_list, data) + try: + torch.distributed.all_gather_object(data_list, data) + except Exception as e: + print(f"Error in all_gather_object: {e}", flush=True) + raise e return flatten(data_list) # type: ignore @@ -43,3 +47,18 @@ def sharded_iterator(ds: data.IterableDataset, num_shards: int, shard_index: int for i, sample in enumerate(ds): if i % num_shards == shard_index: yield sample + + +def sharded_batch_iterator( + ds: data.IterableDataset, batch_size: int, num_shards: int, shard_index: int +): + batch = [] + for idx, sample in enumerate(ds): + if idx % num_shards == shard_index: + batch.append((idx, sample)) + if len(batch) == batch_size: + yield batch + batch = [] + # Yield any remaining samples in the last incomplete batch + if batch: + yield batch diff --git a/ultravox/training/evaluation.py b/ultravox/training/evaluation.py deleted file mode 100644 index 36492d67..00000000 --- a/ultravox/training/evaluation.py +++ /dev/null @@ -1,186 +0,0 @@ -import concurrent.futures -import dataclasses -import functools -import json -import os -from typing import List, Optional - -import numpy as np -from torch.utils import data - -from ultravox.data import datasets -from 
ultravox.evaluation import eval -from ultravox.evaluation import eval_types -from ultravox.inference import infer -from ultravox.training import ddp_utils - - -def dataset_infer( - inference: infer.LocalInference, - ds: data.IterableDataset, - world_size: int = 1, - local_rank: int = 0, - max_new_tokens: Optional[int] = None, - temperature: Optional[float] = None, -) -> List[eval_types.Sample]: - eval_samples = [] - - for sample in ddp_utils.sharded_iterator(ds, world_size, local_rank): - # Store the original question and answer for JSON output. - question_text = sample.audio_transcript or sample.messages[-2]["content"] - expected_answer = sample.messages[-1]["content"] - # Drop any assistant response from the sample. - sample.messages = sample.messages[:-1] - history = sample.messages[:-2] - - output = inference.infer( - sample, max_tokens=max_new_tokens, temperature=temperature - ) - eval_sample = eval_types.Sample( - question=question_text, - generated_answer=output.text, - expected_answer=expected_answer, - history=history, - ) - eval_samples.append(eval_sample) - - # Gather all the samples from all the processes. - return ddp_utils.all_gather_list(eval_samples) - - -@dataclasses.dataclass -class EvalScenario: - name: str - dataset: str - metric: str - include_audio: bool = True - include_context: bool = True - new_tokens: Optional[int] = None - - -EVAL_SCENARIOS = [ - # automatic speech recognition scenarios - EvalScenario("boolq__wer", "boolq_in", "asr"), - # automatic speech translation scenarios - EvalScenario("covost2_en_de__bleu", "covost2:en_de", "bleu"), - EvalScenario("covost2_en_zh-CN__bleu", "covost2:en_zh-CN", "bleu"), - EvalScenario("covost2_es_en__bleu", "covost2:es_en", "bleu"), - EvalScenario( - "covost2_en_de__bleu__text_only", "covost2:en_de", "bleu", include_audio=False - ), - EvalScenario( - "covost2_en_zh-CN__bleu__text_only", - "covost2:en_zh-CN", - "bleu", - include_audio=False, - ), - EvalScenario( - "covost2_es_en__bleu__text_only", "covost2:es_en", "bleu", include_audio=False - ), - # SQA scenarios - EvalScenario("anyinstruct__instruct_follow", "anyinstruct", "instruct"), - EvalScenario( - "boolq__binary", "boolq_extended", "exact_match_last_word", new_tokens=128 - ), - EvalScenario( - "anyinstruct__instruct_follow__text_only", - "anyinstruct", - "instruct", - include_audio=False, - ), - EvalScenario( - "boolq__binary__text_only", - "boolq_extended", - "exact_match_last_word", - new_tokens=128, - include_audio=False, - ), - # Conversation dialogue scenarios - EvalScenario("soda__sensible_generation", "soda", "conversation", new_tokens=64), - EvalScenario( - "soda__sensible_generation__text_only", - "soda", - "conversation", - new_tokens=64, - include_audio=False, - ), -] - - -def evaluate( - inference: infer.LocalInference, - data_dir: Optional[str] = None, - num_samples: int = 200, - num_procs: int = 8, - max_new_tokens: Optional[int] = None, - temperature: Optional[float] = None, - log_dir: Optional[str] = None, -): - metrics = {} - - world_size = int(os.environ.get("WORLD_SIZE", 1)) - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - - if log_dir: - log_dir = os.path.join(log_dir, "evals") - os.makedirs(log_dir, exist_ok=True) - - for task in EVAL_SCENARIOS: - ds_args = datasets.VoiceDatasetArgs( - data_dir=data_dir, - split=datasets.DatasetSplit.VALIDATION, - include_audio=task.include_audio, - include_context=task.include_context, - ) - - ds = datasets.Range(datasets.create_dataset(task.dataset, ds_args), num_samples) - - output_samples = dataset_infer( 
- inference, - ds=ds, - max_new_tokens=task.new_tokens or max_new_tokens, - temperature=temperature, - world_size=world_size, - local_rank=local_rank, - ) - - if local_rank != 0: - # Only the master process should evaluate the samples. - continue - - eval_per_sample = functools.partial(eval.evaluate_answer, metric=task.metric) - - with concurrent.futures.ThreadPoolExecutor(max_workers=num_procs) as executor: - possibly_non_scores = [ - x.score for x in executor.map(eval_per_sample, output_samples) - ] - - if None in possibly_non_scores: - print(f"Failed to evaluate {task.metric} for {task.dataset}") - continue - - scores = [x for x in possibly_non_scores if x is not None] - - average = np.mean(scores) - std = np.std(scores) / np.sqrt(len(scores)) - metrics[f"eval_{task.name}"] = average - metrics[f"eval_{task.name}_std"] = std - - has_audio_str = "with" if task.include_audio else "without" - agg_score_str = f"Aggregate {task.metric} score for {task.dataset} ({has_audio_str} audio): {average:.2f} ± {std:.2f}" - print(agg_score_str) - - if log_dir: - eval_details = { - "score": average, - "confidence_interval": std, - "task_info": dataclasses.asdict(task), - "samples": [ - {**dataclasses.asdict(sample), "score": score} - for sample, score in zip(output_samples, scores) - ], - } - with open(os.path.join(log_dir, f"{task.name}.json"), "w") as f: - json.dump(eval_details, f, indent=1) - - return metrics diff --git a/ultravox/training/train.py b/ultravox/training/train.py index d74620e2..7e26c639 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -1,15 +1,12 @@ -import copy import dataclasses -import gc import glob import logging import os -import subprocess +import sys from datetime import datetime -from typing import Dict, List, Optional +from typing import Dict, List import datasets as hf_datasets -import pandas as pd import safetensors.torch import torch import torch.distributed @@ -19,6 +16,8 @@ from torch.utils import data from ultravox.data import datasets +from ultravox.evaluation import eval +from ultravox.inference import infer from ultravox.model import data_processing from ultravox.model import ultravox_config from ultravox.model import ultravox_model @@ -35,15 +34,17 @@ def prepare_dataset( train_args: config_base.TrainConfig, - dataset_names: List[str], + dataset_configs: List[datasets.DatasetConfig], data_args: datasets.VoiceDatasetArgs, processor: ultravox_processing.UltravoxProcessor, train_on_inputs: bool, stop_strategy: datasets.StopStrategy, - num_samples: Optional[int] = None, include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training) ) -> datasets.SizedIterableDataset: - data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names] + data_sets = [ + datasets.create_dataset(data_args, dataset_config) + for dataset_config in dataset_configs + ] # If we're using epochs to train, validate the dataset length is appropriate. 
if train_args.max_steps == 0: for ds in data_sets: @@ -51,15 +52,16 @@ def prepare_dataset( len(ds) > 1 ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training" - interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy) - ds_with_proc = data_processing.UltravoxDataproc( - interleave, + interleaved_dataset = datasets.InterleaveDataset( + data_sets, stop_strategy=stop_strategy + ) + dataset_with_proc = data_processing.UltravoxDataproc( + interleaved_dataset, processor=processor, train_on_inputs=train_on_inputs, include_alt_fields=include_alt_fields, ) - limited_ds = datasets.Range(ds_with_proc, num_samples=num_samples) - return limited_ds + return dataset_with_proc def main() -> None: @@ -74,16 +76,8 @@ def main() -> None: transformers.set_seed(args.seed) - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - is_master = local_rank == 0 - train(args) - if args.do_eval and is_master: - gc.collect() - torch.cuda.empty_cache() - evaluate(args) - def train(args: config_base.TrainConfig): world_size = int(os.environ.get("WORLD_SIZE", 1)) @@ -114,8 +108,6 @@ def train(args: config_base.TrainConfig): ) text_tokenizer.padding_side = "right" text_tokenizer.pad_token = text_tokenizer.eos_token - audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model) - processor = ultravox_processing.UltravoxProcessor(audio_processor, text_tokenizer) # Instantiate the model and processor config = ultravox_config.UltravoxConfig( @@ -128,6 +120,15 @@ def train(args: config_base.TrainConfig): logging.info("Instantiating model...") model = ultravox_model.UltravoxModel(config) + # loss_config needs to be passed separately just for model training + if args.loss_config is not None: + model.set_loss_config(args.loss_config) + + audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model) + processor = ultravox_processing.UltravoxProcessor( + audio_processor=audio_processor, tokenizer=text_tokenizer + ) + assert model.get_input_embeddings().num_embeddings == len( text_tokenizer ), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}" @@ -138,10 +139,6 @@ def train(args: config_base.TrainConfig): # https://github.com/huggingface/transformers/issues/17116#issuecomment-1121340890 model.audio_tower.config.layerdrop = 0.0 - # loss_config needs to be passed separately just for model training - if args.loss_config is not None: - model.set_loss_config(args.loss_config) - logging.info("Model and processor instantiated.") # Starting W&B. HF Trainer can also do this, but this way we can include the config. @@ -164,6 +161,7 @@ def train(args: config_base.TrainConfig): # and hence this is just resolving the path. If the weights are not downloaded, # we might see a race condition here when using DDP. load_path = wandb_utils.download_model_from_wandb(load_path) + if os.path.isdir(load_path): load_path = os.path.join(load_path, "model*.safetensors") paths = glob.glob(load_path) @@ -189,66 +187,41 @@ def train(args: config_base.TrainConfig): # Prepare dataset, subsetting if needed train_dataset: data.IterableDataset val_datasets: Dict[str, data.IterableDataset] - # We use multiple validation sets here so that the results are comparable even when training set changes - # To make sure we can compare training and validation loss (e.g. for fine-tuning), we keep a special set - # called "matchtrain" that uses the same data as the training set. 
- val_sets = dict( - # [("matchtrain", args.data_sets)] # FIXME: see issue https://github.com/fixie-ai/ultravox/issues/58 - [(x, [x]) for x in args.val_sets] - + [(f"text_{x}", [x]) for x in args.val_sets] - ) + train_dataset = prepare_dataset( train_args=args, - dataset_names=args.data_sets, + dataset_configs=args.train_dataset_configs, train_on_inputs=args.train_on_inputs, stop_strategy=args.stop_strategy, processor=processor, - num_samples=args.num_samples, - data_args=datasets.VoiceDatasetArgs( - num_prompts=args.num_prompts, - data_dir=args.data_dir, - shuffle=args.shuffle_data, - shuffle_seed=args.shuffle_seed, - max_audio_duration_secs=args.max_audio_duration_secs, - use_mds=args.mds, - mds_batch_size=args.batch_size, - ), + data_args=args.train_dataset_args, include_alt_fields=model.loss_config.requires_alt_fields, ) if is_master: - val_ds_args = datasets.VoiceDatasetArgs( - num_prompts=1, - split=datasets.DatasetSplit.VALIDATION, - data_dir=args.data_dir, - shuffle=False, - max_audio_duration_secs=16, - use_mds=args.mds, - mds_batch_size=args.batch_size, - ) - val_ds_args_text = copy.copy(val_ds_args) - val_ds_args_text.include_audio = False val_datasets = { - k: prepare_dataset( + dataset_config.alias: prepare_dataset( train_args=args, - dataset_names=val_sets[k], + dataset_configs=[dataset_config], train_on_inputs=args.train_on_inputs, stop_strategy=args.stop_strategy, processor=processor, - num_samples=args.val_num_samples, - data_args=val_ds_args_text if k.startswith("text_") else val_ds_args, + data_args=args.val_dataset_args, include_alt_fields=model.loss_config.requires_alt_fields, ) - for k in val_sets + for dataset_config in args.val_dataset_configs } logging.info( - f"Loaded {args.data_sets} data sets, sample limit: {args.num_samples} (val sample limit: {args.val_num_samples})" + f"Loaded {len(args.train_dataset_configs)} train data sets, {len(args.val_dataset_configs)} val data sets" ) else: # When using DDP with split_batches=True, the primary process will distribute the batches to the workers # The point of this is to avoid unnecessary data processing/downloading in the workers. 
# When using epochs to train, emptydataset must have a length equal to the training set - train_dataset = datasets.EmptyDataset(len(train_dataset)) - val_datasets = {k: datasets.EmptyDataset() for k in val_sets} + train_dataset = datasets.EmptyDataset() + val_datasets = { + dataset_config.alias: datasets.EmptyDataset() + for dataset_config in args.val_dataset_configs + } # Set up the data loader data_collator = datasets.DataCollatorForSeq2SeqWithAudio( @@ -279,7 +252,8 @@ def train(args: config_base.TrainConfig): logging_steps=args.logging_steps, # TODO (Farzad): reconsider for multi-node # In DDP world_size is set to num_gpus and we want process-0 to split the batches - per_device_train_batch_size=args.batch_size * world_size, + per_device_train_batch_size=args.train_dataset_args.batch_size * world_size, + per_device_eval_batch_size=args.val_dataset_args.batch_size * world_size, accelerator_config={"split_batches": True}, gradient_accumulation_steps=args.grad_accum_steps, eval_accumulation_steps=args.val_accum_steps, @@ -312,101 +286,46 @@ def train(args: config_base.TrainConfig): logging.info(f"train end time: {t_end}") logging.info(f"elapsed: {t_end - t_start}") - if is_master: - # Saving the model using pipeline to ensure its code is saved - pipeline = ultravox_pipeline.UltravoxPipeline( - model, tokenizer=text_tokenizer, device=device + if args.do_eval: + logging.info("Starting evaluation...") + t_start = datetime.now() + logging.info(f"eval start time: {t_start}") + + # Merge LoRA weights for better inference performance. + # Note: this is irreversible and changes model saving format + model.merge_and_unload() + # changing padding side to left for inference + text_tokenizer.padding_side = "left" + inference = infer.LocalInference( + model=model, + processor=processor, + tokenizer=text_tokenizer, + device=args.device, + dtype=dtype, ) - pipeline.save_pretrained(args.output_dir) - - -def evaluate(args: config_base.TrainConfig): - """ - Evaluate the model on the audio and text datasets. - NOTE: This function must be run only on the primary process. - """ - logging.info("Starting evaluation...") - t_start = datetime.now() - logging.info(f"eval start time: {t_start}") - - if args.text_model_lora_config and args.text_model_lora_config.r: - logging.warn( - "Model has unmerged LoRA config. This can lead to slower inference." + metrics, output_files = eval.run_infer( + inference, + args.eval_dataset_args, + args.eval_dataset_configs, + world_size, + local_rank, ) + if is_master: + trainer.log(metrics) + for output_file in output_files: + wandb.save(output_file) - logs_dir = wandb.run.dir if wandb.run else str(args.logs_dir) + t_end = datetime.now() + logging.info(f"eval end time: {t_end}") + logging.info(f"elapsed: {t_end - t_start}") - # Run audio-based evaluations and log to W&B - audio_metrics_df = run_oaievalset( - log_dir=os.path.join(logs_dir, "oaieval/audio"), - model_dir=str(args.output_dir), - eval_set="audio-core", - num_samples=args.eval_num_samples, - ) - # TODO: it would be best to do trainer.log, but then we'd risk keeping parts of the model - # in GPU memory, which could cause OOM errors. 
- if wandb.run: - wandb.run.log({"eval_audio": wandb.Table(data=audio_metrics_df)}) - - if args.eval_text_only: - # Run text-only evaluations and log to W&B - text_metrics_df = run_oaievalset( - log_dir=os.path.join(logs_dir, "oaieval/text"), - model_dir=str(args.output_dir), - eval_set="transcript-core", - num_samples=args.eval_num_samples, + if is_master: + # Saving the model using pipeline to ensure its code is saved + pipeline = ultravox_pipeline.UltravoxPipeline( + model, tokenizer=text_tokenizer, device=device ) - if wandb.run: - wandb.run.log({"eval_text": wandb.Table(data=text_metrics_df)}) - - t_end = datetime.now() - logging.info(f"eval end time: {t_end}") - logging.info(f"elapsed: {t_end - t_start}") - - -def run_oaievalset( - log_dir: str, model_dir: str, eval_set: str, num_samples: Optional[int] = None -) -> pd.DataFrame: - env = os.environ.copy() - - # num_gpus = max(1, torch.cuda.device_count()) - env["EVALS_THREADS"] = "64" - - # TODO: currently running this on a single GPU is faster than multiple GPUs :facepalm: - env["CUDA_VISIBLE_DEVICES"] = "0" - - command = [ - "oaievalset", - "--record_dir", - log_dir, - "generation/gpu/ultravox-dev", - eval_set, - f"--completion_args=model={model_dir}", - ] - if num_samples: - command.append(f"--max_samples={num_samples}") - - # Run the evaluation set - subprocess.run(command, check=True, env=env) - - # Extract the results from the log directory - subprocess.run( - [ - "python", - "-m", - "evals.elsuite.audio.make_table", - "--out_dir", - log_dir, - "--log_dir", - log_dir, - ], - check=True, - ) - - df = pd.read_csv(os.path.join(log_dir, "results.csv")) - - return df + pipeline.save_pretrained(args.output_dir) if __name__ == "__main__": diff --git a/ultravox/utils/device_helpers.py b/ultravox/utils/device_helpers.py new file mode 100644 index 00000000..95931f8e --- /dev/null +++ b/ultravox/utils/device_helpers.py @@ -0,0 +1,49 @@ +import os +from typing import Optional + +import torch + + +def default_device(): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + +def default_dtype(): + # macOS Sonoma 14 enabled bfloat16 on MPS. + return ( + torch.bfloat16 + if torch.cuda.is_available() or torch.backends.mps.is_available() + else torch.float16 + ) + + +def default_dtype_str(): + return ( + "bfloat16" + if torch.cuda.is_available() or torch.backends.mps.is_available() + else "float16" + ) + + +def get_dtype(data_type: Optional[str] = None): + if data_type is None: + return default_dtype() + else: + return ( + torch.bfloat16 + if data_type == "bfloat16" + else torch.float16 if data_type == "float16" else torch.float32 + ) + + +def get_world_size(): + return int(os.environ.get("WORLD_SIZE", 1)) + + +def get_local_rank(): + return int(os.environ.get("LOCAL_RANK", 0)) diff --git a/ultravox/utils/string_helpers.py b/ultravox/utils/string_helpers.py new file mode 100644 index 00000000..6793ca8c --- /dev/null +++ b/ultravox/utils/string_helpers.py @@ -0,0 +1,59 @@ +import re +import sys +import unicodedata + + +def fix_hyphens(arg: str): + return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) + +def sanitize_name(name: str) -> str: + """ + Sanitize a string to be safely used as a valid filename, folder name, or Python module name. + + Parameters: + name (str): The input string to sanitize. + + Returns: + str: The sanitized name. 
+ """ + # Remove leading/trailing whitespace and normalize unicode characters + name = unicodedata.normalize('NFKC', name.strip()) + + # Define invalid characters based on the operating system + if sys.platform.startswith('win'): + invalid_chars = r'[<>:"/\\|?*\x00-\x1F]' + reserved_names = { + 'CON', 'PRN', 'AUX', 'NUL', + *(f'COM{i}' for i in range(1, 10)), + *(f'LPT{i}' for i in range(1, 10)), + } + else: + invalid_chars = r'[<>:"/\\|?*\x00]' + reserved_names = set() + + # Replace invalid characters with underscores + name = re.sub(invalid_chars, '_', name) + + # Remove trailing periods and spaces (Windows limitation) + name = name.rstrip('. ') + + # Handle reserved device names (Windows) + name_part, dot, extension = name.partition('.') + if name_part.upper() in reserved_names: + name_part = f'_{name_part}' + name = name_part + (dot + extension if dot else '') + + # Ensure the name is a valid Python identifier (module name) + # Replace invalid characters with underscores + name = re.sub(r'\W|^(?=\d)', '_', name) + # Ensure it doesn't start with a digit + if not re.match(r'[A-Za-z_]', name): + name = f'_{name}' + # Remove any leading or trailing underscores + name = name.strip('_') + + # If the result is empty, return a default name + if not name: + name = 'default_name' + + return name \ No newline at end of file From 6541b7ab924b49c61760921ecf6810c370449394 Mon Sep 17 00:00:00 2001 From: Zhongqiang Huang Date: Mon, 16 Sep 2024 19:32:33 -0400 Subject: [PATCH 2/4] Update --- ultravox/data/datasets.py | 64 ++++++++++++++++++--------- ultravox/data/datasets_test.py | 26 ++++++----- ultravox/evaluation/eval_types.py | 1 + ultravox/evaluation/gpt_eval_test.py | 1 + ultravox/inference/infer.py | 34 ++++++-------- ultravox/model/ultravox_processing.py | 6 ++- ultravox/training/config_base.py | 10 ++--- ultravox/training/train.py | 1 - ultravox/utils/string_helpers.py | 40 +++++++++-------- 9 files changed, 104 insertions(+), 79 deletions(-) diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 93e9951c..59c4e616 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -180,14 +180,20 @@ def model_post_init(self, __context: Any) -> None: raise ValueError( "num_samples must be non-negative for determing dataset size in all cases (streaming or not)" ) - + # Ensure the audio field is always included - if self.audio_field is not None and self.audio_field not in self.base_audio_columns: + if ( + self.audio_field is not None + and self.audio_field not in self.base_audio_columns + ): self.base_audio_columns.append(self.audio_field) @dataclasses.dataclass class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq): + # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel + include_alt_fields: bool = False + def __call__(self, features, *args, **kwargs): audio_values = [f.pop("audio_values", None) for f in features] if self.include_alt_fields: @@ -374,10 +380,6 @@ def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: f"Created VoiceDataset with config:\n{self._config.model_dump_json(indent=2)}" ) - # _dataset and _length are initialized in _init_dataset - self._dataset = None - self._length = 0 - def _init_dataset(self, dataset: data.Dataset, num_samples: int) -> None: self._dataset = dataset self._length = num_samples @@ -388,7 +390,7 @@ def weight(self) -> float: def __len__(self): return self._length - + def _load_audio_dataset( self, path: str, @@ -420,8 
+422,12 @@ def _load_audio_dataset( else: # HF datasets sometimes fails to download due to network issues, so retry a few times. dataset = datasets.load_dataset( - path, name, split=split, trust_remote_code=True, streaming=streaming, - download_config=datasets.DownloadConfig(max_retries=10) + path, + name, + split=split, + trust_remote_code=True, + streaming=streaming, + download_config=datasets.DownloadConfig(max_retries=10), ) for column_name in self._config.base_audio_columns: dataset = dataset.cast_column( @@ -465,11 +471,11 @@ def __iter__(self): actual_length += 1 if actual_length == len(self) + 1: warnings.warn( - f"The actual number of samples ({actual_length}) has exceeded the presumed length ({self._length}) for dataset {self._config.alias}. Make sure to update." + f"The presumed length {self._length} has been exceeded for dataset {self._config.alias}. Make sure to update." ) if actual_length != len(self): warnings.warn( - f"Mismatch between actual length ({actual_length}) and presumed length ({self._length}) for dataset {self._config.alias}. Make sure to update." + f"Mismatch between presumed length ({self._length}) and actual length ({actual_length}) for dataset {self._config.alias}. Make sure to update." ) @abc.abstractmethod @@ -505,7 +511,7 @@ def _get_transcribe_messages(self, text: str) -> List[Dict[str, str]]: return _get_messages(prompt, text) def _get_audio( - self, row: transformers.BatchFeature, column_name: str = "audio" + self, row: transformers.BatchFeature, column_name: Optional[str] = "audio" ) -> np.ndarray: if column_name not in self._config.base_audio_columns: raise ValueError( @@ -602,12 +608,9 @@ def __init__(self, args: VoiceDatasetArgs, config: DatasetConfig) -> None: dataset = dataset.shuffle(seed=self._args.shuffle_seed) if config.active_samples: - dataset = Range( - SizedIterableDataset(dataset, config.num_samples), config.active_samples - ) + dataset = Range(dataset, config.active_samples, config.num_samples) num_samples = config.active_samples - # super().__init__(args, config, dataset, num_samples) super()._init_dataset(dataset, num_samples) def _get_sample(self, row) -> Optional[VoiceSample]: @@ -648,6 +651,9 @@ class EmptyDataset(SizedIterableDataset): def __iter__(self): return iter([]) + def __len__(self): + return 0 + class AnyInstructDataset(VoiceDataset): """ @@ -1082,18 +1088,29 @@ class Range(SizedIterableDataset): def __init__( self, - dataset: SizedIterableDataset, + dataset: data.IterableDataset, + active_samples: Optional[int] = None, num_samples: Optional[int] = None, ) -> None: self._dataset = dataset - if num_samples is None: - self._length = len(dataset) + + if isinstance(dataset, SizedIterableDataset): + if num_samples is not None and num_samples != len(dataset): + raise ValueError( + f"The presumed num_samples ({num_samples}) is not equal to the number of samples in the dataset ({len(dataset)}). Please update." + ) + num_samples = len(dataset) else: - if num_samples > len(dataset): + if num_samples is None: raise ValueError( - f"num_samples {num_samples} is greater than the number of samples in the dataset {len(dataset)}" + "num_samples must be provided for non-SizedIterableDatasets" ) - self._length = num_samples + + self._length = active_samples or num_samples + if self._length > num_samples: + raise ValueError( + f"The number of samples to limit to ({self._length}) is greater than the number of samples in the dataset ({num_samples}). Please update." 
+ ) def __iter__(self): count = 0 @@ -1107,3 +1124,6 @@ def __iter__(self): raise ValueError( f"Dataset exhausted after {count} samples, expected {self._length}" ) + + def __len__(self): + return self._length diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py index 6cfea7f0..cced82b5 100644 --- a/ultravox/data/datasets_test.py +++ b/ultravox/data/datasets_test.py @@ -14,10 +14,10 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset): """Fake version of datasets.SizedIterableDataset""" - def __init__(self, n, start=0, weight=1, estimated_length=0): + def __init__(self, n, start=0, weight=1, length=0): self.data = range(start, start + n) self._weight = weight - self._estimated_length = estimated_length + self._length = length @property def weight(self) -> float: @@ -28,7 +28,7 @@ def __iter__(self): yield sample def __len__(self): - return self._estimated_length + return self._length class FakeHuggingFaceIterableDataset(hf_datasets.IterableDataset): @@ -58,7 +58,7 @@ def __init__( ): super().__init__( args or datasets.VoiceDatasetArgs(), - config or datasets.DatasetConfig(num_samples=n), + config or datasets.DatasetConfig(num_samples=n, splits=["fake"]), ) self._init_dataset(FakeHuggingFaceIterableDataset(n), n) @@ -185,11 +185,13 @@ def test_interleaved_with_multiprocessing(): def test_range(): - ds = FakeSizedIterableDataset(10, estimated_length=10) + ds = FakeSizedIterableDataset(10, length=10) s = datasets.Range(ds, 5) assert len(s) == 5 assert list(s) == [0, 1, 2, 3, 4] - s = datasets.Range(ds, 100) + with pytest.raises(ValueError, match="is greater than the number of samples"): + s = datasets.Range(ds, 100) + s = datasets.Range(ds, 10) assert list(s) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] s = datasets.Range(ds) assert len(s) == 10 @@ -331,17 +333,21 @@ def test_get_messages(): def test_voice_dataset_size(): - mock_config = datasets.DatasetConfig(path="mock_path", total_samples=5) + mock_config = datasets.DatasetConfig( + path="mock_path", num_samples=5, splits=["fake"] + ) ds = FakeGenericDataset(10, mock_config) assert len(ds) == 5 - pattern = r"(has been exceeded|Mismatch between estimated length)" + pattern = r"(has been exceeded|Mismatch between presumed length)" with pytest.warns(UserWarning, match=pattern): list(ds) - mock_config = datasets.DatasetConfig(path="mock_path", total_samples=10) + mock_config = datasets.DatasetConfig( + path="mock_path", num_samples=10, splits=["fake"] + ) ds = FakeGenericDataset(5, mock_config) assert len(ds) == 10 - with pytest.warns(UserWarning, match="Mismatch between estimated length"): + with pytest.warns(UserWarning, match="Mismatch between presumed length"): list(ds) diff --git a/ultravox/evaluation/eval_types.py b/ultravox/evaluation/eval_types.py index 7a28edb8..0e988eca 100644 --- a/ultravox/evaluation/eval_types.py +++ b/ultravox/evaluation/eval_types.py @@ -4,6 +4,7 @@ import dataclasses_json from pydantic import BaseModel + # Eval config for a single metric, added to the dataset config class EvalConfig(BaseModel): metric: str diff --git a/ultravox/evaluation/gpt_eval_test.py b/ultravox/evaluation/gpt_eval_test.py index c185048a..9575bd0d 100644 --- a/ultravox/evaluation/gpt_eval_test.py +++ b/ultravox/evaluation/gpt_eval_test.py @@ -9,6 +9,7 @@ def test_evaluate_conversation(): gpt_eval.client = mock.MagicMock() sample = eval_types.Sample( + index=0, history=[ {"role": "system", "content": "Blah blah blah"}, {"role": "user", "content": "T1"}, diff --git a/ultravox/inference/infer.py 
b/ultravox/inference/infer.py index 28e50ead..e7c770b7 100644 --- a/ultravox/inference/infer.py +++ b/ultravox/inference/infer.py @@ -38,6 +38,7 @@ def __init__( ) self.data_collator = datasets.DataCollatorForSeq2SeqWithAudio( tokenizer=self.tokenizer, + include_alt_fields=False, ) assert self.tokenizer.padding_side == "left" @@ -60,18 +61,18 @@ def _get_sample_with_past( def _build_past_messages( self, query_messages: List[Dict[str, str]], - num_audio_tokens: int, + audio_token_len: int, response_content: str, ) -> List[Dict[str, str]]: messages = copy.copy(query_messages) - if num_audio_tokens > 0: + if audio_token_len > 0: user_content = messages[-1]["content"] if user_content.count("<|audio|>") != 1: raise ValueError( f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}" ) messages[-1]["content"] = user_content.replace( - "<|audio|>", self.tokenizer.eos_token * num_audio_tokens + "<|audio|>", self.tokenizer.eos_token * audio_token_len ) messages.append({"role": "assistant", "content": response_content}) return messages @@ -92,11 +93,9 @@ def infer( output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True) output_len = len(output_tokens) if self.conversation_mode: - num_audio_tokens = ( - output.past_key_values[0][0].shape[2] - input_len - output_len - ) + audio_token_len = inputs.get("audio_token_len", [0])[0] past_messages = self._build_past_messages( - extended_sample.messages, num_audio_tokens, output_text + extended_sample.messages, audio_token_len, output_text ) self.update_conversation(past_messages, output.past_key_values) return base.VoiceOutput(output_text, input_len, output_len) @@ -142,7 +141,7 @@ def infer_stream( ) -> base.InferenceGenerator: extended_sample = self._get_sample_with_past(sample) inputs = self._dataproc(extended_sample) - input_len = inputs["input_ids"].shape[1] + input_tokens = inputs["input_ids"].shape[1] streamer = transformers.TextIteratorStreamer( self.tokenizer, skip_prompt=True, skip_special_tokens=True ) @@ -159,24 +158,21 @@ def thunk(f: futures.Future): thread = threading.Thread(target=thunk, args=(future,)) thread.start() output_text = "" + output_token_len = 0 for chunk in streamer: if chunk: output_text += chunk + output_token_len += 1 yield base.InferenceChunk(chunk) thread.join() - output_len = len( - self.tokenizer(output_text, add_special_tokens=False).input_ids - ) output = future.result() if self.conversation_mode: - num_audio_tokens = ( - output.past_key_values[0][0].shape[2] - input_len - output_len - ) + audio_token_len = inputs.get("audio_token_len", [0])[0] past_messages = self._build_past_messages( - extended_sample.messages, num_audio_tokens, output_text + extended_sample.messages, audio_token_len, output_text ) self.update_conversation(past_messages, output.past_key_values) - yield base.InferenceStats(input_len, output_len) + yield base.InferenceStats(input_tokens, output_token_len) def _dataproc(self, sample: datasets.VoiceSample): text_input = self.tokenizer.apply_chat_template( @@ -228,10 +224,8 @@ def _generate( do_sample = temperature is not None terminators = [self.tokenizer.eos_token_id] - if "" in self.tokenizer.added_tokens_encoder: - terminators.append( - self.tokenizer.convert_tokens_to_ids("<|end_header_id|>") - ) + if "<|eot_id|>" in self.tokenizer.added_tokens_encoder: + terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>")) return self.model.generate( **inputs, diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 
211f7f0a..68260344 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -63,8 +63,10 @@ def __init__( @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config: UltravoxConfig = transformers.AutoConfig.from_pretrained( - pretrained_model_name_or_path, **kwargs + config: UltravoxConfig = ( # type: ignore + transformers.AutoConfig.from_pretrained( + pretrained_model_name_or_path, **kwargs + ) ) audio_processor = transformers.AutoProcessor.from_pretrained( config.audio_model_id diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 6568e418..eced8fe6 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -2,7 +2,6 @@ import datetime import logging import os -import re import sys from pathlib import Path from typing import Any, Dict, List, Optional @@ -11,7 +10,8 @@ from ultravox.data import datasets from ultravox.model import ultravox_config -from ultravox.utils import device_helpers, string_helpers +from ultravox.utils import device_helpers +from ultravox.utils import string_helpers @dataclasses.dataclass @@ -24,9 +24,6 @@ class TrainConfig: audio_model: str """Audio encoder model to use; could be a huggingface model id, wandb path, or local path.""" - adapter_type: ultravox_config.AdapterType = ultravox_config.AdapterType.STACKING - """Type of adapter to use for fine-tuning. Defaults to STACKING.""" - adapter_config: Optional[Dict[str, Any]] = None """Optional configuration dictionary for the adapter. If None, default settings will be used.""" @@ -220,6 +217,7 @@ def __post_init__(self): ) self.disable_layerdrop = True + def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig: """ Parse the command line arguments and return a TrainConfig object. @@ -235,4 +233,4 @@ def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"), add_config_path_arg=True, args=[string_helpers.fix_hyphens(arg) for arg in args], - ) \ No newline at end of file + ) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 7e26c639..81bbdee2 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -2,7 +2,6 @@ import glob import logging import os -import sys from datetime import datetime from typing import Dict, List diff --git a/ultravox/utils/string_helpers.py b/ultravox/utils/string_helpers.py index 6793ca8c..e337993a 100644 --- a/ultravox/utils/string_helpers.py +++ b/ultravox/utils/string_helpers.py @@ -6,6 +6,7 @@ def fix_hyphens(arg: str): return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) + def sanitize_name(name: str) -> str: """ Sanitize a string to be safely used as a valid filename, folder name, or Python module name. @@ -17,43 +18,46 @@ def sanitize_name(name: str) -> str: str: The sanitized name. 
""" # Remove leading/trailing whitespace and normalize unicode characters - name = unicodedata.normalize('NFKC', name.strip()) + name = unicodedata.normalize("NFKC", name.strip()) # Define invalid characters based on the operating system - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): invalid_chars = r'[<>:"/\\|?*\x00-\x1F]' - reserved_names = { - 'CON', 'PRN', 'AUX', 'NUL', - *(f'COM{i}' for i in range(1, 10)), - *(f'LPT{i}' for i in range(1, 10)), + reserved_names: set[str] = { + "CON", + "PRN", + "AUX", + "NUL", + *(f"COM{i}" for i in range(1, 10)), + *(f"LPT{i}" for i in range(1, 10)), } else: invalid_chars = r'[<>:"/\\|?*\x00]' - reserved_names = set() + reserved_names: set[str] = set() # Replace invalid characters with underscores - name = re.sub(invalid_chars, '_', name) + name = re.sub(invalid_chars, "_", name) # Remove trailing periods and spaces (Windows limitation) - name = name.rstrip('. ') + name = name.rstrip(". ") # Handle reserved device names (Windows) - name_part, dot, extension = name.partition('.') + name_part, dot, extension = name.partition(".") if name_part.upper() in reserved_names: - name_part = f'_{name_part}' - name = name_part + (dot + extension if dot else '') + name_part = f"_{name_part}" + name = name_part + (dot + extension if dot else "") # Ensure the name is a valid Python identifier (module name) # Replace invalid characters with underscores - name = re.sub(r'\W|^(?=\d)', '_', name) + name = re.sub(r"\W|^(?=\d)", "_", name) # Ensure it doesn't start with a digit - if not re.match(r'[A-Za-z_]', name): - name = f'_{name}' + if not re.match(r"[A-Za-z_]", name): + name = f"_{name}" # Remove any leading or trailing underscores - name = name.strip('_') + name = name.strip("_") # If the result is empty, return a default name if not name: - name = 'default_name' + name = "default_name" - return name \ No newline at end of file + return name From 22ba1f80da7eeb309f56127eb3219d56cdedff1f Mon Sep 17 00:00:00 2001 From: Zhongqiang Huang Date: Mon, 16 Sep 2024 19:34:58 -0400 Subject: [PATCH 3/4] Update --- mcloud_eval.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mcloud_eval.yaml b/mcloud_eval.yaml index c3fbf542..8acdbcd8 100644 --- a/mcloud_eval.yaml +++ b/mcloud_eval.yaml @@ -6,17 +6,17 @@ compute: cluster: r14z3p1 integrations: - integration_type: git_repo - git_repo: fixie-ai/ultravox-expts + git_repo: fixie-ai/ultravox git_branch: $UV_BRANCH pip_install: poetry==1.7.1 scheduling: max_duration: 2 # 2 hours max for jobs to avoid hanging jobs command: >- - cd ultravox-expts && + cd ultravox && poetry install --no-dev && poetry run python -m ultravox.training.helpers.prefetch_weights $EVAL_ARGS && poetry run torchrun --nproc_per_node=8 -m ultravox.tools.eval_tool $EVAL_ARGS --exp_name $UV_BRANCH env_variables: MLFLOW_TRACKING_URI: databricks - UV_BRANCH: zhuang.2024-08-12-ultravox.merge_p1 + UV_BRANCH: zhuang/unified_dataset_specification EVAL_ARGS: --config_path ultravox/evaluation/configs/eval_config_2k.yaml From fcb81fffb0a469867eb9e9ec2850e66822352075 Mon Sep 17 00:00:00 2001 From: Zhongqiang Huang Date: Mon, 16 Sep 2024 19:54:20 -0400 Subject: [PATCH 4/4] Update --- mcloud_eval.yaml | 4 +- mcloud_train.yaml | 10 ++-- ultravox/data/datasets.py | 2 +- ultravox/evaluation/configs/eval_config.yaml | 2 +- .../evaluation/configs/eval_config_2k.yaml | 2 +- .../evaluation/configs/eval_config_tiny.yaml | 2 +- ultravox/tools/eval_tool.py | 59 +------------------ ultravox/training/config_base.py | 4 +- 
ultravox/training/helpers/prefetch_weights.py | 6 +- 9 files changed, 21 insertions(+), 70 deletions(-) diff --git a/mcloud_eval.yaml b/mcloud_eval.yaml index 8acdbcd8..36d8c4c6 100644 --- a/mcloud_eval.yaml +++ b/mcloud_eval.yaml @@ -15,8 +15,8 @@ command: >- cd ultravox && poetry install --no-dev && poetry run python -m ultravox.training.helpers.prefetch_weights $EVAL_ARGS && - poetry run torchrun --nproc_per_node=8 -m ultravox.tools.eval_tool $EVAL_ARGS --exp_name $UV_BRANCH + poetry run torchrun --nproc_per_node=8 -m ultravox.tools.eval_tool $EVAL_ARGS env_variables: MLFLOW_TRACKING_URI: databricks - UV_BRANCH: zhuang/unified_dataset_specification + UV_BRANCH: main EVAL_ARGS: --config_path ultravox/evaluation/configs/eval_config_2k.yaml diff --git a/mcloud_train.yaml b/mcloud_train.yaml index 53fb24d0..fcafbe11 100644 --- a/mcloud_train.yaml +++ b/mcloud_train.yaml @@ -6,17 +6,17 @@ compute: cluster: r14z3p1 integrations: - integration_type: git_repo - git_repo: fixie-ai/ultravox-expts + git_repo: fixie-ai/ultravox git_branch: $UV_BRANCH pip_install: poetry==1.7.1 scheduling: max_duration: 6 # 6 hours max for jobs to avoid hanging jobs command: >- - cd ultravox-expts && + cd ultravox && poetry install --no-dev && poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS && - poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS --exp_name $UV_BRANCH + poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS env_variables: MLFLOW_TRACKING_URI: databricks - UV_BRANCH: zhuang.2024-08-12-ultravox.merge_p1 - TRAIN_ARGS: --config_path ultravox/training/configs/expt_config.yaml \ No newline at end of file + UV_BRANCH: main + TRAIN_ARGS: --config_path ultravox/training/configs/release_config.yaml \ No newline at end of file diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py index 59c4e616..49c68e7f 100644 --- a/ultravox/data/datasets.py +++ b/ultravox/data/datasets.py @@ -1105,7 +1105,7 @@ def __init__( raise ValueError( "num_samples must be provided for non-SizedIterableDatasets" ) - + self._length = active_samples or num_samples if self._length > num_samples: raise ValueError( diff --git a/ultravox/evaluation/configs/eval_config.yaml b/ultravox/evaluation/configs/eval_config.yaml index 53c866a9..7bbcfa6d 100644 --- a/ultravox/evaluation/configs/eval_config.yaml +++ b/ultravox/evaluation/configs/eval_config.yaml @@ -1,4 +1,4 @@ -model_path: "fixie-ai/ultravox-v0_4" +model_load_dir: "fixie-ai/ultravox-v0_4" report_logs_to: ["wandb"] output_dir: "./output" diff --git a/ultravox/evaluation/configs/eval_config_2k.yaml b/ultravox/evaluation/configs/eval_config_2k.yaml index b69305b5..17fbdd93 100644 --- a/ultravox/evaluation/configs/eval_config_2k.yaml +++ b/ultravox/evaluation/configs/eval_config_2k.yaml @@ -1,4 +1,4 @@ -model_path: "fixie-ai/ultravox-v0_4" +model_load_dir: "fixie-ai/ultravox-v0_4" report_logs_to: ["wandb"] output_dir: "./output" diff --git a/ultravox/evaluation/configs/eval_config_tiny.yaml b/ultravox/evaluation/configs/eval_config_tiny.yaml index 963a93bd..22e6c83b 100644 --- a/ultravox/evaluation/configs/eval_config_tiny.yaml +++ b/ultravox/evaluation/configs/eval_config_tiny.yaml @@ -1,4 +1,4 @@ -model_path: "fixie-ai/ultravox-v0_4" +model_load_dir: "fixie-ai/ultravox-v0_4" report_logs_to: ["wandb"] output_dir: "./output" diff --git a/ultravox/tools/eval_tool.py b/ultravox/tools/eval_tool.py index 82a98c6f..3e20eb6d 100644 --- a/ultravox/tools/eval_tool.py +++ b/ultravox/tools/eval_tool.py @@ -1,20 
+1,17 @@ #!/usr/bin/env python import dataclasses -import datetime import logging import os import sys -from pathlib import Path -from typing import List, Optional import simple_parsing import torch import torch.distributed as dist import wandb -from ultravox.data import datasets from ultravox.evaluation import eval from ultravox.inference import ultravox_infer +from ultravox.training import config_base from ultravox.training import ddp_utils from ultravox.utils import device_helpers from ultravox.utils import string_helpers @@ -22,59 +19,9 @@ logging.basicConfig(level=logging.INFO) -@dataclasses.dataclass -class EvalArgs: - model_path: str = simple_parsing.field() - """Model ID to use for the model""" - - eval_dataset_configs: List[datasets.DatasetConfig] = dataclasses.field( - default_factory=list - ) - """List of evaluation dataset configurations""" - - eval_dataset_args: datasets.VoiceDatasetArgs = dataclasses.field( - default_factory=datasets.VoiceDatasetArgs - ) - """Global arguments for the evaluation dataset""" - - device: str = device_helpers.default_device() - """Device to use for training (e.g., 'cuda', 'cpu', 'mps')""" - - data_type: str = device_helpers.default_dtype_str() - """Data type to use for training (e.g., 'bfloat16', 'float16', 'float32')""" - - exp_name: Optional[str] = simple_parsing.field(default=None) - """The experiment name""" - - output_dir: Optional[Path] = simple_parsing.field(default=None) - """Output directory""" - - report_logs_to: List[str] = simple_parsing.field(default_factory=list) - """Whether to report results to wandb""" - - def __post_init__(self): - self.eval_dataset_configs = [ - datasets.DatasetConfig(**config) for config in self.eval_dataset_configs - ] - - if self.data_type not in ["bfloat16", "float16", "float32", None]: - raise ValueError( - f"Invalid data type: {self.data_type}. Please specify one of the following: bfloat16, float16, float32." 
- ) - if self.device is None: - self.device = device_helpers.default_device() - - if self.output_dir is None: - self.output_dir = Path("runs") / self.exp_name - self.output_dir = Path(self.output_dir) - - if self.exp_name is None: - self.exp_name = datetime.datetime.now().strftime("exp--%Y-%m-%d--%H-%M-%S") - - def main(): args = simple_parsing.parse( - config_class=EvalArgs, + config_class=config_base.TrainConfig, add_config_path_arg=True, args=[string_helpers.fix_hyphens(arg) for arg in sys.argv[1:]], ) @@ -97,7 +44,7 @@ def main(): with ddp_utils.run_on_master_first(local_rank == 0): inference = ultravox_infer.UltravoxInference( - args.model_path, + args.model_load_dir, device=device, data_type=device_helpers.get_dtype(args.data_type), ) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index eced8fe6..08ba0fa1 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -18,10 +18,10 @@ class TrainConfig: """Configuration class for training settings.""" - text_model: str + text_model: Optional[str] = None """Language model to use; could be a huggingface model id, wandb path, or local path.""" - audio_model: str + audio_model: Optional[str] = None """Audio encoder model to use; could be a huggingface model id, wandb path, or local path.""" adapter_config: Optional[Dict[str, Any]] = None diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 30449a38..55a2b64b 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -22,8 +22,12 @@ def main(override_sys_args: Optional[List[str]] = None): print(f"Weights downloaded in {end - start} seconds") -def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None): +def download_weights( + model_ids: List[Optional[str]], model_load_dir: Optional[str] = None +): for model_id in model_ids: + if model_id is None: + continue try: # Download all model files that match ALLOW_PATTERNS # This is faster than .from_pretrained due to parallel downloads