Added whisperX support #125
base: main
@@ -11,6 +11,7 @@ RUN export DEBIAN_FRONTEND=noninteractive \
        python${PYTHON_VERSION}-venv \
        python3-pip \
        ffmpeg \
        git \
Review comment on the added "git \" line: And here as well... (presumably referring to the gcc / python3-dev suggestion at the bottom of this page)
    && rm -rf /var/lib/apt/lists/*

RUN ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \
@@ -35,6 +36,13 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui.css swagger-ui-assets/sw
COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-assets/swagger-ui-bundle.js

RUN poetry install
RUN $POETRY_VENV/bin/pip install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch
RUN $POETRY_VENV/bin/pip install torch torchaudio pandas transformers nltk pyannote.audio \
    --index-url https://download.pytorch.org/whl/cu118 \
    --index-url https://pypi.org/simple/

RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
    && cd whisperX \
    && $POETRY_VENV/bin/pip install --no-dependencies -e .

EXPOSE 9000
CMD gunicorn --bind 0.0.0.0:9000 --workers 1 --timeout 0 app.webservice:app -k uvicorn.workers.UvicornWorker
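A note on the pip step above: pip treats --index-url as a single-valued option, so when it is passed twice the last value (pypi.org) wins and the cu118 PyTorch index is effectively ignored. If the intent is to pull CUDA 11.8 wheels for torch/torchaudio while still resolving the remaining packages from PyPI, the conventional form uses --extra-index-url. A possible variant (an assumption about the intent, not what the diff currently does):

RUN $POETRY_VENV/bin/pip install torch torchaudio pandas transformers nltk pyannote.audio \
    --index-url https://download.pytorch.org/whl/cu118 \
    --extra-index-url https://pypi.org/simple/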
@@ -0,0 +1,99 @@
import os
from typing import BinaryIO, Union
from io import StringIO
from threading import Lock
import torch
import whisper
import whisperx
from whisper.utils import ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON

model_name = os.getenv("ASR_MODEL", "base")
hf_token = os.getenv("HF_TOKEN", "")
x_models = dict()

if torch.cuda.is_available():
    device = "cuda"
    model = whisper.load_model(model_name).cuda()
Review comment: Why not use whisperx model for transcription?
Author reply: Thank you for your feedback! I fixed this issue, so whisperx should also be used for transcription now.
(A sketch of what using whisperX for transcription could look like follows after this file's diff.)
    if hf_token != "":
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
else:
    device = "cpu"
    model = whisper.load_model(model_name)
    if hf_token != "":
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)

model_lock = Lock()
def transcribe(
    audio,
    task: Union[str, None],
    language: Union[str, None],
    initial_prompt: Union[str, None],
    word_timestamps: Union[bool, None],
    options: Union[dict, None],
    output
):
    options_dict = {"task": task}
    if language:
        options_dict["language"] = language
    if initial_prompt:
        options_dict["initial_prompt"] = initial_prompt
    with model_lock:
        result = model.transcribe(audio, **options_dict)

    # Load the required alignment model and cache it.
    # If we transcribe audio in many different languages, this may lead to OOM problems.
    if result["language"] in x_models:
        print('Using cached model')
        model_x, metadata = x_models[result["language"]]
    else:
        print(f'Loading model {result["language"]}')
        x_models[result["language"]] = whisperx.load_align_model(language_code=result["language"], device=device)
        model_x, metadata = x_models[result["language"]]

    # Align whisper output
    result = whisperx.align(result["segments"], model_x, metadata, audio, device, return_char_alignments=False)
    if options["diarize"]:
        if hf_token == "":
            print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
        min_speakers = options["min_speakers"]
        max_speakers = options["max_speakers"]
        # Pass min/max number of speakers if known
        diarize_segments = diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
        result = whisperx.assign_word_speakers(diarize_segments, result)

    outputFile = StringIO()
    write_result(result, outputFile, output)
    outputFile.seek(0)

    return outputFile
def language_detection(audio):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language
    with model_lock:
        _, probs = model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)

    return detected_lang_code
def write_result(
    result: dict, file: BinaryIO, output: Union[str, None]
):
    if output == "srt":
        WriteSRT(ResultWriter).write_result(result, file=file, options={})
    elif output == "vtt":
        WriteVTT(ResultWriter).write_result(result, file=file, options={})
    elif output == "tsv":
        WriteTSV(ResultWriter).write_result(result, file=file, options={})
    elif output == "json":
        WriteJSON(ResultWriter).write_result(result, file=file, options={})
    elif output == "txt":
        WriteTXT(ResultWriter).write_result(result, file=file, options={})
    else:
        return 'Please select an output method!'
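For context on the review exchange above ("Why not use whisperx model for transcription?"): switching transcription itself to whisperX would roughly mean replacing the whisper.load_model / model.transcribe calls with whisperX's own loader, which wraps a faster-whisper backend and supports batched inference. A minimal sketch, assuming the standard whisperx API; the model name, input path, compute type and batch size are illustrative, not taken from this PR:

import torch
import whisperx

model_name = "base"                                      # illustrative; the service reads ASR_MODEL
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"

# load the whisperX (faster-whisper) model instead of openai-whisper
model = whisperx.load_model(model_name, device, compute_type=compute_type)

audio = whisperx.load_audio("sample.wav")                # hypothetical input path
result = model.transcribe(audio, batch_size=16)          # returns {"segments": [...], "language": "..."}

The downstream steps in the diff (whisperx.load_align_model, whisperx.align, diarization) would then consume result["segments"] and result["language"] unchanged.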
Review comment: add gcc and python3-dev packages here
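Applying that suggestion (and the matching "And here as well..." remark on the Dockerfile hunk above) would mean extending the apt-get package list, roughly as below. This is a sketch of the reviewer's request, not part of the current diff; gcc and python3-dev are typically needed so pip can build C extensions for dependencies that ship without prebuilt wheels:

        python${PYTHON_VERSION}-venv \
        python3-pip \
        ffmpeg \
        git \
        gcc \
        python3-dev \
    && rm -rf /var/lib/apt/lists/*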