diff --git a/setup.cfg b/setup.cfg index 6212e8cb33..e91bf0bc9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,8 +54,6 @@ install_requires= scipy~=1.10 uncertainty-calibration~=0.1.4 scikit-learn~=1.1 - jiwer~=3.0 - rapidfuzz~=3.10 # Models and Metrics Extras transformers~=4.40 # For anthropic_client, vision_language.huggingface_vlm_client, huggingface_client, huggingface_tokenizer, test_openai_token_cost_estimator, model_summac (via summarization_metrics) @@ -293,6 +291,7 @@ audiolm = pycocoevalcap~=1.2 jiwer~=3.0 rapidfuzz~=3.10 + jieba~=0.42.1 # Install everything all = diff --git a/src/helm/benchmark/metrics/evaluate_reference_metrics.py b/src/helm/benchmark/metrics/evaluate_reference_metrics.py index 994f049cde..45cc447835 100644 --- a/src/helm/benchmark/metrics/evaluate_reference_metrics.py +++ b/src/helm/benchmark/metrics/evaluate_reference_metrics.py @@ -9,7 +9,6 @@ from nltk.translate.bleu_score import sentence_bleu from rouge_score import rouge_scorer import numpy as np -from jiwer import wer, mer, wip, cer from helm.benchmark.adaptation.adapter_spec import AdapterSpec from helm.benchmark.adaptation.request_state import RequestState @@ -208,6 +207,11 @@ def wa_score(gold: str, pred: str) -> float: # metric used to evaluate the accuracy of speech recognition systems. # Note that this metric could be negative because the WER might be greater than 1. # https://huggingface.co/learn/audio-course/en/chapter5/evaluation#word-error-rate + try: + from jiwer import wer + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + if not pred: return 0 @@ -218,6 +222,11 @@ def wa_score(gold: str, pred: str) -> float: def ma_score(gold: str, pred: str) -> float: # Match Accuracy (MA) equals to 1 - match error rate (MER), which is for evaluating the accuracy of # speech recognition systems. + try: + from jiwer import mer + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + if not pred: return 0 mer_ret = mer(gold, pred) @@ -227,6 +236,11 @@ def ma_score(gold: str, pred: str) -> float: def wip_score(gold: str, pred: str) -> float: # Word information preservation (WIP) for evaluating the preserved information of speech # recognition systems. + try: + from jiwer import wip + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + if not pred: return 0 wip_ret = wip(gold, pred) @@ -236,12 +250,53 @@ def wip_score(gold: str, pred: str) -> float: def ca_score(gold: str, pred: str) -> float: # Character accuracy (CA) equals to character error rate (CER) for evaluating the accuracy # of speech recognition systems. + try: + from jiwer import cer + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + if not pred: return 0 cer_ret = cer(gold, pred) return cer_ret +def chinese_wa_score(gold: str, pred: str) -> float: + try: + import jieba + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + + return wa_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred))) + + +def chinese_ma_score(gold: str, pred: str) -> float: + try: + import jieba + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + + return ma_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred))) + + +def chinese_wip_score(gold: str, pred: str) -> float: + try: + import jieba + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + + return wip_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred))) + + +def chinese_ca_score(gold: str, pred: str) -> float: + try: + import jieba + except ModuleNotFoundError as e: + handle_module_not_found_error(e, ["audiolm"]) + + return ca_score(" ".join(jieba.cut(gold)), " ".join(jieba.cut(pred))) + + def extract_set_from_text( set_str: str, set_start_str: str = " is ", @@ -395,6 +450,10 @@ def compute_metrics_helper( "ma_score": ma_score, "wip_score": wip_score, "ca_score": ca_score, + "chinese_wa_score": chinese_wa_score, + "chinese_ma_score": chinese_ma_score, + "chinese_wip_score": chinese_wip_score, + "chinese_ca_score": chinese_ca_score, } stats: List[Stat] = [] diff --git a/src/helm/benchmark/run_specs/audio_run_specs.py b/src/helm/benchmark/run_specs/audio_run_specs.py index 50373bccf6..ecd21cec26 100644 --- a/src/helm/benchmark/run_specs/audio_run_specs.py +++ b/src/helm/benchmark/run_specs/audio_run_specs.py @@ -60,6 +60,10 @@ def _get_open_ended_generation_metric_specs() -> List[MetricSpec]: ) +def _get_chinese_audio_recognition_metric_specs() -> List[MetricSpec]: + return get_basic_metric_specs(["chinese_wa_score", "chinese_ma_score", "chinese_wip_score", "chinese_ca_score"]) + + ######################################################################################################################## # RunSpecs @@ -173,7 +177,10 @@ def get_multilingual_librispeech_run_spec(language: str) -> RunSpec: "Respond with only the transcript text.", max_tokens=100, ) - metric_specs = _get_audio_recognition_metric_specs() + if "chinese" in language.lower(): + metric_specs = _get_chinese_audio_recognition_metric_specs() + else: + metric_specs = _get_audio_recognition_metric_specs() return RunSpec( name="multilingual_librispeech", scenario_spec=scenario_spec, @@ -208,7 +215,7 @@ def get_fleurs_run_spec(language: str) -> RunSpec: @run_spec_function("audiocaps") def get_audiocaps_run_spec() -> RunSpec: scenario_spec = ScenarioSpec( - class_name="helm.benchmark.scenarios.audio_language.audiocaps_senario.AudioCapsScenario" + class_name="helm.benchmark.scenarios.audio_language.audiocaps_scenario.AudioCapsScenario" ) adapter_spec = _get_generation_adapter_spec( instructions="Generate a caption for the following audio. The caption should be short and does " @@ -223,3 +230,49 @@ def get_audiocaps_run_spec() -> RunSpec: metric_specs=metric_specs, groups=["audiocaps"], ) + + +@run_spec_function("common_voice_15") +def get_common_voice_15_run_spec(language: str) -> RunSpec: + scenario_spec = ScenarioSpec( + class_name="helm.benchmark.scenarios.audio_language.common_voice_15_scenario.CommonVoice15Scenario", + args={"language": language}, + ) + adapter_spec = _get_generation_adapter_spec( + instructions="Listen to the audio and generate an accurate transcript of the spoken content. " + "Respond with only the transcript text.", + max_tokens=100, + ) + # Chinese characters are not supported in the default metrics + if "chinese" in language.lower(): + metric_specs = _get_chinese_audio_recognition_metric_specs() + else: + metric_specs = _get_audio_recognition_metric_specs() + return RunSpec( + name="common_voice_15", + scenario_spec=scenario_spec, + adapter_spec=adapter_spec, + metric_specs=metric_specs, + groups=["common_voice_15"], + ) + + +@run_spec_function("speech_robust_bench") +def get_speech_robust_bench_run_spec(subject: str) -> RunSpec: + scenario_spec = ScenarioSpec( + class_name="helm.benchmark.scenarios.audio_language.speech_robust_bench_scenario.SpeechRobustBenchScenario", + args={"subject": subject}, + ) + adapter_spec = _get_generation_adapter_spec( + instructions="Listen to the audio and generate an accurate transcript of the spoken content. " + "Respond with only the transcript text.", + max_tokens=100, + ) + metric_specs = _get_audio_recognition_metric_specs() + return RunSpec( + name="speech_robust_bench", + scenario_spec=scenario_spec, + adapter_spec=adapter_spec, + metric_specs=metric_specs, + groups=["speech_robust_bench"], + ) diff --git a/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py b/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py index eca166d8e0..4b54d82927 100644 --- a/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py +++ b/src/helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py @@ -16,7 +16,7 @@ from datasets import load_dataset from helm.common.media_object import MediaObject, MultimediaObject from helm.common.general import ensure_directory_exists -from helm.common.audio_utils import ensure_wav_file_exists_from_array +from helm.common.audio_utils import ensure_audio_file_exists_from_array class AudioMNISTScenario(Scenario): @@ -53,7 +53,7 @@ def get_instances(self, output_path: str) -> List[Instance]: for row in tqdm(load_dataset("flexthink/audiomnist", cache_dir=output_path, split=TEST_SPLIT)): local_audio_path = os.path.join(wav_save_dir, row["audio"]["path"]) audio_array = row["audio"]["array"] - ensure_wav_file_exists_from_array(local_audio_path, audio_array, row["audio"]["sampling_rate"]) + ensure_audio_file_exists_from_array(local_audio_path, audio_array, row["audio"]["sampling_rate"]) input = Input( multimedia_content=MultimediaObject([MediaObject(content_type="audio/wav", location=local_audio_path)]) ) diff --git a/src/helm/benchmark/scenarios/audio_language/audiocaps_senario.py b/src/helm/benchmark/scenarios/audio_language/audiocaps_scenario.py similarity index 100% rename from src/helm/benchmark/scenarios/audio_language/audiocaps_senario.py rename to src/helm/benchmark/scenarios/audio_language/audiocaps_scenario.py diff --git a/src/helm/benchmark/scenarios/audio_language/common_voice_15_scenario.py b/src/helm/benchmark/scenarios/audio_language/common_voice_15_scenario.py new file mode 100644 index 0000000000..f2513d1223 --- /dev/null +++ b/src/helm/benchmark/scenarios/audio_language/common_voice_15_scenario.py @@ -0,0 +1,98 @@ +"""Scenarios for audio models""" + +from typing import List + +from helm.benchmark.scenarios.scenario import ( + Scenario, + Instance, + Reference, + TEST_SPLIT, + CORRECT_TAG, + Input, + Output, +) +from collections import OrderedDict +from tqdm import tqdm +from datasets import load_dataset +from helm.common.media_object import MediaObject, MultimediaObject +from helm.common.hierarchical_logger import hlog + + +class CommonVoice15Scenario(Scenario): + """CommonVoice15 Scenario + + The most recent release of CommonVoice15 (Ardila et al, 2019) includes 114 languages. Over 50,000 + individuals have participated so far, resulting in 2,500 hours of collected audio. This is the largest + audio corpus in the public domain for speech recognition, both in terms of number of hours and number + of languages. The task is to recognize the speech from the audio sample. + + + + Paper: https://arxiv.org/abs/1912.06670 + Code: https://github.com/common-voice/common-voice + + Citation: + @article{ardila2019common, + title={Common voice: A massively-multilingual speech corpus}, + author={Ardila, Rosana and Branson, Megan and Davis, Kelly and + Henretty, Michael and Kohler, Michael and Meyer, Josh and Morais, + Reuben and Saunders, Lindsay and Tyers, Francis M and Weber, Gregor}, + journal={arXiv preprint arXiv:1912.06670}, + year={2019} + } + + """ + + HF_DATASET_NAME = "mozilla-foundation/common_voice_15_0" + + # Randomly selected 4 languages from 114 languages in the Common Voice 15 dataset following + # Qwen2-Audio (https://arxiv.org/abs/2407.10759). The full language is: + # https://huggingface.co/datasets/mozilla-foundation/common_voice_15_0/blob/main/languages.py + _COMMON_VOICE_TEST_LANG_TO_ID = OrderedDict( + [ + ("English", "en"), + ("Chinese_hk", "zh-HK"), + ("German", "de"), + ("French", "fr"), + ] + ) + + name = "common_voice_15" + description = "Speech recognition for 4 languages from 114 different languages in Common Voice 15 \ + ([Ardila et al, 2019](https://arxiv.org/abs/1912.06670))." + tags: List[str] = ["audio", "recognition"] + + def __init__(self, language: str) -> None: + super().__init__() + + language = language.capitalize() + if language not in CommonVoice15Scenario._COMMON_VOICE_TEST_LANG_TO_ID.keys(): + raise ValueError( + f"Invalid language. Valid languages are: {CommonVoice15Scenario._COMMON_VOICE_TEST_LANG_TO_ID.keys()}" + ) + + self._language: str = language + hlog( + "You need to sign in Huggingface to download the dataset. Please remember " + "to sign in to download the dataset." + ) + + def get_instances(self, output_path: str) -> List[Instance]: + instances: List[Instance] = [] + language_category = CommonVoice15Scenario._COMMON_VOICE_TEST_LANG_TO_ID[self._language] + for row in tqdm( + load_dataset( + CommonVoice15Scenario.HF_DATASET_NAME, + name=language_category, + cache_dir=output_path, + split=TEST_SPLIT, + ) + ): + local_audio_path = row["path"] + answer = row["sentence"] + input = Input( + multimedia_content=MultimediaObject([MediaObject(content_type="audio/mpeg", location=local_audio_path)]) + ) + references = [Reference(Output(text=answer), tags=[CORRECT_TAG])] + instances.append(Instance(input=input, references=references, split=TEST_SPLIT)) + return instances diff --git a/src/helm/benchmark/scenarios/audio_language/fleurs_scenario.py b/src/helm/benchmark/scenarios/audio_language/fleurs_scenario.py index 7e35d7dc60..582e6851ce 100644 --- a/src/helm/benchmark/scenarios/audio_language/fleurs_scenario.py +++ b/src/helm/benchmark/scenarios/audio_language/fleurs_scenario.py @@ -29,7 +29,15 @@ class FLEURSScenario(Scenario): Code: https://tensorflow.org/datasets/catalog/xtreme_s Citation: - + @inproceedings{conneau2023fleurs, + title={Fleurs: Few-shot learning evaluation of universal representations of speech}, + author={Conneau, Alexis and Ma, Min and Khanuja, Simran and Zhang, Yu and Axelrod, + Vera and Dalmia, Siddharth and Riesa, Jason and Rivera, Clara and Bapna, Ankur}, + booktitle={2022 IEEE Spoken Language Technology Workshop (SLT)}, + pages={798--805}, + year={2023}, + organization={IEEE} + } """ HF_DATASET_NAME = "google/xtreme_s" diff --git a/src/helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py b/src/helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py index 8801de47f8..abcc9776fd 100644 --- a/src/helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py +++ b/src/helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py @@ -11,7 +11,7 @@ Input, Output, ) -from helm.common.audio_utils import ensure_wav_file_exists_from_array +from helm.common.audio_utils import ensure_audio_file_exists_from_array from helm.common.general import ensure_directory_exists from helm.common.media_object import MediaObject, MultimediaObject @@ -65,7 +65,7 @@ def get_instances(self, output_path: str) -> List[Instance]: wav_path = os.path.join(wav_dir, row["audio"]["path"]) print(len(row["audio"]["array"])) print(list(row["audio"]["array"])[0:10]) - ensure_wav_file_exists_from_array( + ensure_audio_file_exists_from_array( wav_path, row["audio"]["array"], sample_rate=IEMOCAPAudioScenario.SAMPLE_RATE ) input = Input( diff --git a/src/helm/benchmark/scenarios/audio_language/meld_audio_scenario.py b/src/helm/benchmark/scenarios/audio_language/meld_audio_scenario.py index d3ccff75e7..dcb95cacbd 100644 --- a/src/helm/benchmark/scenarios/audio_language/meld_audio_scenario.py +++ b/src/helm/benchmark/scenarios/audio_language/meld_audio_scenario.py @@ -15,7 +15,7 @@ Input, Output, ) -from helm.common.audio_utils import ensure_wav_file_exists_from_array, get_array_from_audio_file +from helm.common.audio_utils import ensure_audio_file_exists_from_array, get_array_from_audio_file from helm.common.media_object import MediaObject, MultimediaObject from helm.common.general import ensure_directory_exists, ensure_file_downloaded @@ -98,7 +98,7 @@ def get_instances(self, output_path: str) -> List[Instance]: flac_file_name = f"dia{row.Dialogue_ID}_utt{row.Utterance_ID}.flac" flac_file_path = os.path.join(flac_dir, flac_file_name) audio_array = get_array_from_audio_file(flac_file_path, MELDAudioScenario.SAMPLE_RATE) - ensure_wav_file_exists_from_array(wav_file_path, audio_array, MELDAudioScenario.SAMPLE_RATE) + ensure_audio_file_exists_from_array(wav_file_path, audio_array, MELDAudioScenario.SAMPLE_RATE) input = Input( multimedia_content=MultimediaObject( media_objects=[MediaObject(location=wav_file_path, content_type="audio/wav")] diff --git a/src/helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py b/src/helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py index bf0f4e5f7b..df4b9e1e96 100644 --- a/src/helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py +++ b/src/helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py @@ -15,7 +15,7 @@ from tqdm import tqdm from datasets import load_dataset from helm.common.media_object import MediaObject, MultimediaObject -from helm.common.audio_utils import ensure_mp3_file_exists_from_array +from helm.common.audio_utils import ensure_audio_file_exists_from_array class MultilingualLibriSpeechScenario(Scenario): @@ -70,7 +70,7 @@ def get_instances(self, output_path: str) -> List[Instance]: ): local_audio_path = os.path.join(audio_save_dir, row["original_path"].split("/")[-1]) # download to the local path - ensure_mp3_file_exists_from_array(local_audio_path, row["audio"]["array"], row["audio"]["sampling_rate"]) + ensure_audio_file_exists_from_array(local_audio_path, row["audio"]["array"], row["audio"]["sampling_rate"]) answer = row["transcript"] input = Input( multimedia_content=MultimediaObject([MediaObject(content_type="audio/mpeg", location=local_audio_path)]) diff --git a/src/helm/benchmark/scenarios/audio_language/speech_robust_bench_scenario.py b/src/helm/benchmark/scenarios/audio_language/speech_robust_bench_scenario.py new file mode 100644 index 0000000000..84fc81a473 --- /dev/null +++ b/src/helm/benchmark/scenarios/audio_language/speech_robust_bench_scenario.py @@ -0,0 +1,91 @@ +"""Scenarios for audio models""" + +from typing import List +import os + +from helm.benchmark.scenarios.scenario import ( + Scenario, + Instance, + Reference, + TEST_SPLIT, + CORRECT_TAG, + Input, + Output, +) +from tqdm import tqdm +from datasets import load_dataset +from helm.common.media_object import MediaObject, MultimediaObject +from helm.common.audio_utils import ensure_audio_file_exists_from_array + + +class SpeechRobustBenchScenario(Scenario): + """Speech Robust Bench Scenario + + Speech Robust Bench (Shah et al, 2024) is a comprehensive benchmark for evaluating + the robustness of ASR models to diverse corruptions. SRB is composed of 114 input + perturbations which simulate an heterogeneous range of corruptions that ASR models + may encounter when deployed in the wild. In this scenario, we select four subsets + in the benchmark for evaluation. + + Paper: https://arxiv.org/abs/2403.07937 + Code: https://github.com/ahmedshah1494/speech_robust_bench + + Citation: + @article{shah2024speech, + title={Speech robust bench: A robustness benchmark for speech recognition}, + author={Shah, Muhammad A and Noguero, David Solans and Heikkila, Mikko A and Raj, + Bhiksha and Kourtellis, Nicolas}, + journal={arXiv preprint arXiv:2403.07937}, + year={2024} + } + """ + + HF_DATASET_NAME = "mshah1/speech_robust_bench" + + # Select four subsets of the dataset for the benchmark + SUBJECTS_DICT = { + "accented_cv": {"name": "accented_cv", "split": TEST_SPLIT, "type": "audio/mpeg"}, + "accented_cv_es": {"name": "accented_cv_es", "split": TEST_SPLIT, "type": "audio/mpeg"}, + "chime_far": {"name": "chime", "split": "farfield", "type": "audio/wav"}, + "chime_near": {"name": "chime", "split": "nearfield", "type": "audio/wav"}, + "ami_far": {"name": "in-the-wild-AMI", "split": "farfield", "type": "audio/wav"}, + "ami_near": {"name": "in-the-wild-AMI", "split": "nearfield", "type": "audio/wav"}, + } + + name = "speech_robust_bench" + description = ( + "Speech recognition for 4 datasets with a wide range of corruptions" + "([Shah et al, 2024](https://arxiv.org/abs/2403.07937))." + ) + tags: List[str] = ["audio", "recognition"] + + def __init__(self, subject: str) -> None: + super().__init__() + + self._subject = subject + if self._subject not in SpeechRobustBenchScenario.SUBJECTS_DICT.keys(): + raise ValueError(f"Invalid subject. Valid subjects are: {SpeechRobustBenchScenario.SUBJECTS_DICT.keys()}") + + def get_instances(self, output_path: str) -> List[Instance]: + instances: List[Instance] = [] + subject_name = SpeechRobustBenchScenario.SUBJECTS_DICT[self._subject]["name"] + subject_split = SpeechRobustBenchScenario.SUBJECTS_DICT[self._subject]["split"] + subject_type = SpeechRobustBenchScenario.SUBJECTS_DICT[self._subject]["type"] + audio_save_dir = os.path.join(output_path, "audio_files") + for row in tqdm( + load_dataset( + SpeechRobustBenchScenario.HF_DATASET_NAME, + name=subject_name, + cache_dir=output_path, + split=subject_split, + ) + ): + local_audio_path = os.path.join(audio_save_dir, row["audio"]["path"]) + ensure_audio_file_exists_from_array(local_audio_path, row["audio"]["array"], row["audio"]["sampling_rate"]) + answer = row["text"] + input = Input( + multimedia_content=MultimediaObject([MediaObject(content_type=subject_type, location=local_audio_path)]) + ) + references = [Reference(Output(text=answer), tags=[CORRECT_TAG])] + instances.append(Instance(input=input, references=references, split=TEST_SPLIT)) + return instances diff --git a/src/helm/benchmark/static/schema_speech.yaml b/src/helm/benchmark/static/schema_speech.yaml index 0613efb83a..f0ccae7b80 100644 --- a/src/helm/benchmark/static/schema_speech.yaml +++ b/src/helm/benchmark/static/schema_speech.yaml @@ -123,6 +123,19 @@ metrics: description: BLEU score based on [Post, (2018)](https://aclanthology.org/W18-6319/). lower_is_better: false + # Speech Recognition metrics + - name: wa + display_name: WA + short_display_name: WA + description: Word Accuracy based on the Word Error Rate. + lower_is_better: true + + - name: CA + display_name: WA + short_display_name: WA + description: Word Accuracy based on the Word Error Rate. + lower_is_better: true + ############################################################ perturbations: [] @@ -167,6 +180,11 @@ run_groups: - audio_mnist - covost2 - vocal_sound + - multilingual_librispeech + - fleurs + - audiocaps + - common_voice_15 + - speech_robust_bench - name: audio_mnist display_name: AudioMNIST @@ -299,4 +317,53 @@ run_groups: what: audio clips in the wild who: real speakers when: "2019" - language: English \ No newline at end of file + language: English + + - name: common_voice_15 + display_name: Common Voice 15 + description: > + The most recent release of CommonVoice15 (Ardila et al, 2019) includes 114 languages. Over 50,000 + individuals have participated so far, resulting in 2,500 hours of collected audio. This is the largest + audio corpus in the public domain for speech recognition, both in terms of number of hours and number + of languages. The task is to recognize the speech from the audio sample. + + The dataset contains the audio, transcriptions, location, speaker's gender, age, and accent + ([Ardila et al, 2020](https://arxiv.org/abs/1912.06670)). + metric_groups: + - accuracy + - efficiency + - general_information + environment: + main_name: word_accuracy + main_split: test + taxonomy: + task: audio recognition + what: audio, transcripts, location, speaker's information in 114 languages + who: real speakers + when: "2019" + language: 114 languages + + - name: robust_speech_bench + display_name: Robust Speech Bench + description: > + Speech Robust Bench (Shah et al, 2024) is a comprehensive benchmark for evaluating + the robustness of ASR models to diverse corruptions. SRB is composed of 114 input + perturbations which simulate an heterogeneous range of corruptions that ASR models + may encounter when deployed in the wild. In this scenario, we select 4 subsets: + accent_cv, accent_cv_es, chinme, and AIM for evaluation. + + The dataset contains the audio, transcriptions for all subsets + ([Shah et al, 2024](https://arxiv.org/abs/2403.07937)). + metric_groups: + - accuracy + - efficiency + - general_information + environment: + main_name: word_accuracy + main_split: test + taxonomy: + task: audio recognition + what: audio, transcripts of audio samples in a wide range of perturbations + who: real speakers + when: "2024" + language: English, Spanish \ No newline at end of file diff --git a/src/helm/common/audio_utils.py b/src/helm/common/audio_utils.py index 0f1b783d8a..efc1ee4301 100644 --- a/src/helm/common/audio_utils.py +++ b/src/helm/common/audio_utils.py @@ -5,42 +5,26 @@ import librosa import numpy as np -from scipy.io import wavfile +import soundfile as sf from helm.common.multimodal_request_utils import get_contents_as_bytes -def ensure_wav_file_exists_from_array(path: str, array: np.ndarray, sample_rate: int) -> None: - """Write the array to the wav file if it does not already exist. +def ensure_audio_file_exists_from_array(path: str, array: np.ndarray, sample_rate: int) -> None: + """Write the array to the wav or mp3 file if it does not already exist. Uses file locking and an atomic rename to avoid file corruption due to incomplete writes and concurrent writes.""" - if not path.endswith(".wav"): - raise ValueError(f"Path must end with .wav: {path}") + file_extension = os.path.splitext(path)[1] + if file_extension != ".wav" and file_extension != ".mp3": + raise ValueError(f"Path must end with .wav or .mp3: {path}") with FileLock(f"{path}.lock"): if os.path.exists(path): # Skip because file already exists return - path_prefix = path.removesuffix(".wav") - tmp_path = f"{path_prefix}.tmp.wav" - wavfile.write(filename=tmp_path, rate=sample_rate, data=array) - os.rename(tmp_path, path) - - -def ensure_mp3_file_exists_from_array(path: str, array: np.ndarray, sample_rate: int) -> None: - """Write the array to the mp3 file if it does not already exist. - - Uses file locking and an atomic rename to avoid file corruption due to incomplete writes and - concurrent writes.""" - if not path.endswith(".mp3"): - raise ValueError(f"Path must end with .mp3: {path}") - with FileLock(f"{path}.lock"): - if os.path.exists(path): - # Skip because file already exists - return - path_prefix = path.removesuffix(".mp3") - tmp_path = f"{path_prefix}.tmp.mp3" - wavfile.write(filename=tmp_path, rate=sample_rate, data=array) + path_prefix = path.removesuffix(file_extension) + tmp_path = f"{path_prefix}.tmp{file_extension}" + sf.write(tmp_path, array, samplerate=sample_rate) os.rename(tmp_path, path)