-
Notifications
You must be signed in to change notification settings - Fork 217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add faster-whisper (ctranslate2) as option for Whisper annotation workflow #1017
base: master
Are you sure you want to change the base?
Changes from all commits
363c756
0f5a2e1
706a33a
79e47d8
d722e5b
f4a28af
b98703e
3c052f8
381709c
bbd556c
8903a7c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,3 +1,4 @@ | ||||||
from functools import partial | ||||||
from typing import List, Optional, Union | ||||||
|
||||||
import click | ||||||
|
@@ -55,6 +56,36 @@ def workflows(): | |||||
@click.option( | ||||||
"-d", "--device", default="cpu", help="Device on which to run the inference." | ||||||
) | ||||||
@click.option( | ||||||
"--faster-whisper", | ||||||
is_flag=True, | ||||||
default=True, | ||||||
help="If True, use faster-whisper's implementation based on CTranslate2.", | ||||||
) | ||||||
@click.option( | ||||||
"--faster-whisper-use-vad", | ||||||
is_flag=True, | ||||||
default=True, | ||||||
help="If True, use faster-whisper's built-in voice activity detection (SileroVAD)." | ||||||
"Note: This requires onnxruntime to be installed.", | ||||||
) | ||||||
@click.option( | ||||||
"--faster-whisper-add-alignments", | ||||||
is_flag=True, | ||||||
default=False, | ||||||
help="If True, add word alignments using timestamps obtained using the cross-attention" | ||||||
"pattern and dynamic time warping (Note: Less accurate than forced alignment).", | ||||||
) | ||||||
@click.option( | ||||||
"--faster-whisper-compute-type", | ||||||
default="float16", | ||||||
Review comment: The suggested change is to set the default compute type to "default" rather than "float16", because float16 computation is not supported on (some?) CPUs and the option would otherwise fail there.
||||||
help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.", | ||||||
) | ||||||
@click.option( | ||||||
"--faster-whisper-num-workers", | ||||||
default=1, | ||||||
help="Number of workers for parallelization across multiple GPUs.", | ||||||
) | ||||||
@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.") | ||||||
@click.option( | ||||||
"--force-nonoverlapping/--keep-overlapping", | ||||||
|
@@ -72,6 +103,11 @@ def annotate_with_whisper( | |||||
device: str, | ||||||
jobs: int, | ||||||
force_nonoverlapping: bool, | ||||||
faster_whisper: bool, | ||||||
faster_whisper_use_vad: bool, | ||||||
faster_whisper_compute_type: str, | ||||||
faster_whisper_add_alignments: bool, | ||||||
faster_whisper_num_workers: int, | ||||||
): | ||||||
""" | ||||||
Use OpenAI Whisper model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. | ||||||
|
@@ -83,7 +119,18 @@ def annotate_with_whisper( | |||||
Note: this is an experimental feature of Lhotse, and is not guaranteed to yield | ||||||
high quality of data. | ||||||
""" | ||||||
from lhotse import annotate_with_whisper as annotate_with_whisper_ | ||||||
if faster_whisper: | ||||||
from lhotse import annotate_with_faster_whisper | ||||||
|
||||||
annotate_with_whisper_ = partial( | ||||||
annotate_with_faster_whisper, | ||||||
compute_type=faster_whisper_compute_type, | ||||||
num_workers=faster_whisper_num_workers, | ||||||
vad_filter=faster_whisper_use_vad, | ||||||
add_alignments=faster_whisper_add_alignments, | ||||||
) | ||||||
else: | ||||||
from lhotse import annotate_with_whisper as annotate_with_whisper_ | ||||||
|
||||||
assert exactly_one_not_null(recordings_manifest, recordings_dir, cuts_manifest), ( | ||||||
"Options RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive " | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .faster_whisper import annotate_with_faster_whisper | ||
from .forced_alignment import align_with_torchaudio | ||
from .meeting_simulation import * | ||
from .whisper import annotate_with_whisper |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
import logging | ||
import warnings | ||
from concurrent.futures import as_completed | ||
from concurrent.futures.thread import ThreadPoolExecutor | ||
from typing import Any, Generator, List, Optional, Union | ||
|
||
import numpy as np | ||
|
||
from lhotse import ( | ||
CutSet, | ||
MonoCut, | ||
Recording, | ||
RecordingSet, | ||
SupervisionSegment, | ||
add_durations, | ||
) | ||
from lhotse.qa import trim_supervisions_to_recordings | ||
from lhotse.supervision import AlignmentItem | ||
from lhotse.utils import fastcopy, is_module_available | ||
|
||
|
||
def annotate_with_faster_whisper(
    manifest: Union[RecordingSet, CutSet],
    model_name: str = "base",
    device: str = "cpu",
    force_nonoverlapping: bool = False,
    download_root: Optional[str] = None,
    compute_type: str = "default",
    num_workers: int = 1,
    vad_filter: bool = True,
    add_alignments: bool = False,
    **decode_options,
) -> Generator[MonoCut, None, None]:
    """
    Use OpenAI Whisper model via faster-whisper and CTranslate2 to annotate either
    RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. It will perform automatic
    segmentation, transcription, and language identification. If the first argument is
    a CutSet, it will overwrite the supervisions with the results of the inference.

    Note: this is an experimental feature of Lhotse, and is not guaranteed to yield
    high quality of data.

    See the original repo for more details: https://github.com/guillaumekln/faster-whisper

    :param manifest: a ``RecordingSet`` or ``CutSet`` object.
    :param model_name: one of available Whisper variants (base, medium, large, etc.).
    :param device: Where to run the inference (cpu, cuda, etc.).
    :param force_nonoverlapping: if True, the Whisper segment time-stamps will be processed to
        make sure they are non-overlapping.
    :param download_root: directory where model files should be saved. Kept to maintain
        compatibility with ``annotate_with_whisper``; support depends on the installed
        faster-whisper version.
    :param compute_type: Type to use for computation.
        See https://opennmt.net/CTranslate2/quantization.html.
    :param num_workers: Increasing the number of workers can improve the global throughput
        at the cost of increased memory usage.
    :param vad_filter: If True, use faster-whisper's built-in voice activity detection
        (SileroVAD). Requires onnxruntime to be installed.
    :param add_alignments: if True, add word alignments using timestamps obtained using the
        cross-attention pattern and dynamic time warping (Note: Less accurate than forced
        alignment).
    :param decode_options: additional options to pass to ``WhisperModel.transcribe``
        (e.g. ``language`` to specify the language upfront instead of auto-detecting it).
    :return: a generator of cuts (use ``CutSet.open_writer()`` to write them).
    """
    assert is_module_available("faster_whisper"), (
        "This function expects faster-whisper to be installed. "
        "You can install it via 'pip install faster-whisper' "
        "(see https://github.com/guillaumekln/faster-whisper/ for details)."
    )
    if not isinstance(manifest, (RecordingSet, CutSet)):
        raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.")
    assert not vad_filter or is_module_available("onnxruntime"), (
        "Use of VAD requires onnxruntime to be installed. "
        "You can install it via 'pip install onnxruntime' "
        "(see https://github.com/guillaumekln/faster-whisper/ for details)."
    )
    if vad_filter and add_alignments:
        warnings.warn(
            "Word timestamps can be very inaccurate when using VAD. We don't recommend using both "
            "options together. See https://github.com/guillaumekln/faster-whisper/issues/125."
        )

    model = _initialize_model(
        model_name, device, compute_type, num_workers, download_root
    )
    with ThreadPoolExecutor(num_workers) as ex:
        futures = [
            ex.submit(
                _process_single_manifest,
                item,
                model,
                force_nonoverlapping,
                vad_filter,
                add_alignments,
                **decode_options,
            )
            for item in manifest
        ]
        for future in as_completed(futures):
            result = future.result()
            # Skipped items (e.g. multi-channel recordings) come back as an empty
            # list rather than a MonoCut; don't leak them to the caller.
            if not isinstance(result, list):
                yield result
|
||
|
||
def _initialize_model(
    model_name: str,
    device: str,
    compute_type: str = "default",
    num_workers: int = 1,
    download_root: Optional[str] = None,
):
    """
    Load a faster-whisper ``WhisperModel`` on the requested device(s).

    :param model_name: one of available Whisper variants (base, medium, large, etc.).
    :param device: device spec, optionally with an index, e.g. ``"cpu"``, ``"cuda"``,
        ``"cuda:1"``.
    :param compute_type: CTranslate2 computation type.
    :param num_workers: number of workers; on CUDA without an explicit device index,
        workers are spread across the available GPUs.
    :param download_root: directory to save model files in, if supported by the installed
        faster-whisper version.
    :return: a loaded ``WhisperModel`` instance with its logger (if any) quieted to WARNING.
    """
    import inspect

    import torch
    from faster_whisper import WhisperModel

    # Parse an optional device index, e.g. "cuda:1" -> ("cuda", 1).
    device, _, idx = device.partition(":")
    if len(idx) > 0:
        device_index = int(idx)
    elif num_workers > 1 and device == "cuda":
        # No explicit index: spread workers across GPUs (at most one per GPU).
        num_workers = min(num_workers, torch.cuda.device_count())
        device_index = list(range(num_workers))
    else:
        device_index = 0

    kwargs = dict(
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        num_workers=num_workers,
    )
    # ``download_root`` is not supported by all released faster-whisper versions;
    # only pass it when the installed version accepts it.
    if "download_root" in inspect.signature(WhisperModel.__init__).parameters:
        kwargs["download_root"] = download_root
    elif download_root is not None:
        logging.warning(
            "The installed faster-whisper version does not support 'download_root'; "
            "the argument will be ignored."
        )
    model = WhisperModel(model_name, **kwargs)
    # Some faster-whisper versions do not expose a ``logger`` attribute.
    model_logger = getattr(model, "logger", None)
    if model_logger is not None:
        model_logger.setLevel(logging.WARNING)
    return model
|
||
|
||
def _process_single_manifest(
    manifest: Union[Recording, MonoCut],
    model,
    force_nonoverlapping: bool,
    vad_filter: bool,
    add_alignments: bool = False,
    **decode_options,
) -> MonoCut:
    """
    Transcribe a single Recording or MonoCut with a faster-whisper model and return
    a MonoCut whose supervisions hold the transcription results.

    Multi-channel recordings are not supported and are skipped with a warning
    (an empty list is returned in that case).

    :param manifest: the item to annotate.
    :param model: a loaded faster-whisper ``WhisperModel``.
    :param force_nonoverlapping: if True, post-process segment timestamps so that
        they do not overlap.
    :param vad_filter: if True, enable faster-whisper's built-in VAD.
    :param add_alignments: if True, attach word-level alignments to each supervision.
    :param decode_options: extra keyword arguments for ``model.transcribe``.
    """
    if isinstance(manifest, Recording):
        if manifest.num_channels > 1:
            logging.warning(
                f"Skipping recording '{manifest.id}'. It has {manifest.num_channels} channels, "
                f"but we currently only support mono input."
            )
            return []
        recording_id = manifest.id
    else:
        recording_id = manifest.recording_id

    # faster-whisper expects 16 kHz mono audio as a 1-D array.
    audio = np.squeeze(manifest.resample(16000).load_audio())
    segments, info = model.transcribe(
        audio=audio,
        word_timestamps=add_alignments,
        vad_filter=vad_filter,
        **decode_options,
    )

    # Create supervisions from segments while filtering out those with negative duration.
    # Note: the segment index counts all segments, including the filtered ones, to keep
    # supervision IDs aligned with Whisper's segment numbering.
    supervisions = []
    for segment_id, segment in enumerate(segments):
        if segment.end - segment.start <= 0:
            continue
        supervision = SupervisionSegment(
            id=f"{manifest.id}-{segment_id:06d}",
            recording_id=recording_id,
            start=round(segment.start, ndigits=8),
            duration=add_durations(segment.end, -segment.start, sampling_rate=16000),
            text=segment.text.strip(),
            language=info.language,
        )
        if add_alignments:
            supervision = supervision.with_alignment(
                "word",
                [
                    AlignmentItem(
                        symbol=ws.word.strip(),
                        start=round(ws.start, ndigits=8),
                        duration=round(ws.end - ws.start, ndigits=8),
                        score=round(ws.probability, ndigits=3),
                    )
                    for ws in segment.words
                ],
            )
        supervisions.append(supervision)

    if isinstance(manifest, Recording):
        cut = manifest.to_cut()
        if supervisions:
            supervisions = (
                _postprocess_timestamps(supervisions)
                if force_nonoverlapping
                else supervisions
            )
            # Clamp supervisions to the recording duration to guard against
            # timestamps slightly past the end of the audio.
            cut.supervisions = list(
                trim_supervisions_to_recordings(
                    recordings=manifest, supervisions=supervisions, verbose=False
                )
            )
    else:
        cut = fastcopy(
            manifest,
            supervisions=_postprocess_timestamps(supervisions)
            if force_nonoverlapping
            else supervisions,
        )

    return cut
|
||
|
||
def _postprocess_timestamps(supervisions: List[SupervisionSegment]): | ||
""" | ||
Whisper tends to have a lot of overlapping segments due to inaccurate end timestamps. | ||
Under a strong assumption that the input speech is non-overlapping, we can fix that | ||
by always truncating to the start timestamp of the next segment. | ||
""" | ||
from cytoolz import sliding_window | ||
|
||
supervisions = sorted(supervisions, key=lambda s: s.start) | ||
|
||
if len(supervisions) < 2: | ||
return supervisions | ||
out = [] | ||
for cur, nxt in sliding_window(2, supervisions): | ||
if cur.end > nxt.start: | ||
cur = cur.trim(end=nxt.start) | ||
out.append(cur) | ||
out.append(nxt) | ||
return out |
Review comment: Please change the "--faster-whisper" option to the paired flag form ("--faster-whisper/--no-faster-whisper"); with "is_flag=True" and "default=True", the option can never be turned off from the command line.