diff --git a/Dockerfile b/Dockerfile
index 42d638c..5466ea8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,6 +7,7 @@ RUN export DEBIAN_FRONTEND=noninteractive \
     && apt-get -qq update \
     && apt-get -qq install --no-install-recommends \
     ffmpeg \
+    git \
     && rm -rf /var/lib/apt/lists/*
 
 RUN python3 -m venv $POETRY_VENV \
@@ -24,4 +25,10 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-ass
 RUN poetry config virtualenvs.in-project true
 RUN poetry install
+RUN /app/.venv/bin/pip install pandas transformers nltk pyannote.audio
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install -e .
+
+EXPOSE 9000
 
 ENTRYPOINT ["gunicorn", "--bind", "0.0.0.0:9000", "--workers", "1", "--timeout", "0", "app.webservice:app", "-k", "uvicorn.workers.UvicornWorker"]
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
index c3a8eee..5b31896 100644
--- a/Dockerfile.gpu
+++ b/Dockerfile.gpu
@@ -11,6 +11,7 @@ RUN export DEBIAN_FRONTEND=noninteractive \
     python${PYTHON_VERSION}-venv \
     python3-pip \
     ffmpeg \
+    git \
     && rm -rf /var/lib/apt/lists/*
 
 RUN ln -s -f /usr/bin/python${PYTHON_VERSION} /usr/bin/python3 && \
@@ -35,6 +36,13 @@ COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui.css swagger-ui-assets/sw
 COPY --from=swagger-ui /usr/share/nginx/html/swagger-ui-bundle.js swagger-ui-assets/swagger-ui-bundle.js
 
 RUN poetry install
-RUN $POETRY_VENV/bin/pip install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch
+RUN /app/.venv/bin/pip install torch torchaudio pandas transformers nltk pyannote.audio \
+    --index-url https://download.pytorch.org/whl/cu118 \
+    --extra-index-url https://pypi.org/simple/
 
-CMD gunicorn --bind 0.0.0.0:9000 --workers 1 --timeout 0 app.webservice:app -k uvicorn.workers.UvicornWorker
+RUN git clone --depth 1 https://github.com/m-bain/whisperX.git \
+    && cd whisperX \
+    && $POETRY_VENV/bin/pip install --no-deps -e .
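As a quick sanity check that the editable whisperX install is importable inside either image, something like the following can be run in the container. This is a minimal sketch, not part of the diff; the script name, entrypoint override, and the `compute_type="int8"` choice (to keep the check runnable on CPU-only hosts) are assumptions:

```python
# smoke_check.py -- hypothetical sanity check, not part of this PR.
# Run inside the built image, e.g.:
#   docker run --rm --entrypoint /app/.venv/bin/python <image> /app/smoke_check.py
import whisperx

# "base" matches the default ASR_MODEL used by the webservice;
# int8 avoids the float16 default, which is not supported on CPU.
model = whisperx.load_model("base", device="cpu", compute_type="int8")
print("whisperX model loaded:", type(model).__name__)
```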
+
+EXPOSE 9000
+CMD gunicorn --bind 0.0.0.0:9000 --workers 1 --timeout 0 app.webservice:app -k uvicorn.workers.UvicornWorker
\ No newline at end of file
diff --git a/README.md b/README.md
index 75bd94c..3eacd71 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ Current release (v1.2.0) supports following whisper models:
 
 - [openai/whisper](https://github.com/openai/whisper)@[v20230918](https://github.com/openai/whisper/releases/tag/v20230918)
 - [guillaumekln/faster-whisper](https://github.com/guillaumekln/faster-whisper)@[0.9.0](https://github.com/guillaumekln/faster-whisper/releases/tag/v0.9.0)
+- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
 
 ## Quick Usage
 
diff --git a/app/faster_whisper/core.py b/app/faster_whisper/core.py
index aa6563a..987b64b 100644
--- a/app/faster_whisper/core.py
+++ b/app/faster_whisper/core.py
@@ -25,6 +25,7 @@ def transcribe(
     language: Union[str, None],
     initial_prompt: Union[str, None],
     word_timestamps: Union[bool, None],
+    options: Union[dict, None],
     output,
 ):
     options_dict = {"task": task}
diff --git a/app/mbain_whisperx/core.py b/app/mbain_whisperx/core.py
new file mode 100644
index 0000000..9dd2a59
--- /dev/null
+++ b/app/mbain_whisperx/core.py
@@ -0,0 +1,105 @@
+import os
+from typing import BinaryIO, Union
+from io import StringIO
+from threading import Lock
+
+import torch
+import whisperx
+import whisper
+from whisperx.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteSRT, WriteVTT, WriteTSV, WriteJSON
+
+model_name = os.getenv("ASR_MODEL", "base")
+hf_token = os.getenv("HF_TOKEN", "")
+x_models = dict()
+
+if torch.cuda.is_available():
+    device = "cuda"
+    model = whisperx.load_model(model_name, device=device)
+    if hf_token != "":
+        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
+else:
+    device = "cpu"
+    model = whisperx.load_model(model_name, device=device)
+    if hf_token != "":
+        diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
+
+model_lock = Lock()
+
+def transcribe(
+    audio,
+    task: Union[str, None],
+    language: Union[str, None],
+    initial_prompt: Union[str, None],
+    word_timestamps: Union[bool, None],
+    options: Union[dict, None],
+    output
+):
+    options_dict = {"task": task}
+    if language:
+        options_dict["language"] = language
+    if initial_prompt:
+        options_dict["initial_prompt"] = initial_prompt
+    with model_lock:
+        result = model.transcribe(audio, **options_dict)
+
+    # Load the required alignment model and cache it.
+    # If we transcribe audio in many different languages, this may lead to OOM problems.
+    if result["language"] in x_models:
+        print('Using cached model')
+        model_x, metadata = x_models[result["language"]]
+    else:
+        print(f'Loading model {result["language"]}')
+        x_models[result["language"]] = whisperx.load_align_model(language_code=result["language"], device=device)
+        model_x, metadata = x_models[result["language"]]
+
+    # Align whisper output
+    result = whisperx.align(result["segments"], model_x, metadata, audio, device, return_char_alignments=False)
+
+    if options["diarize"]:
+        if hf_token == "":
+            print("Warning! HF_TOKEN is not set. Diarization may not work as expected.")
+        min_speakers = options["min_speakers"]
+        max_speakers = options["max_speakers"]
+        # add min/max number of speakers if known
+        diarize_segments = diarize_model(audio, min_speakers, max_speakers)
+        result = whisperx.assign_word_speakers(diarize_segments, result)
+
+    outputFile = StringIO()
+    write_result(result, outputFile, output)
+    outputFile.seek(0)
+
+    return outputFile
+
+def language_detection(audio):
+    # load audio and pad/trim it to fit 30 seconds
+    audio = whisper.pad_or_trim(audio)
+
+    # make log-Mel spectrogram and move to the same device as the model
+    mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+    # detect the spoken language
+    with model_lock:
+        _, probs = model.detect_language(mel)
+    detected_lang_code = max(probs, key=probs.get)
+
+    return detected_lang_code
+
+def write_result(result: dict, file: BinaryIO, output: Union[str, None]):
+    if output == "srt":
+        if hf_token != "":
+            WriteSRT(SubtitlesWriter).write_result(result, file=file, options={})
+        else:
+            WriteSRT(ResultWriter).write_result(result, file=file, options={})
+    elif output == "vtt":
+        if hf_token != "":
+            WriteVTT(SubtitlesWriter).write_result(result, file=file, options={})
+        else:
+            WriteVTT(ResultWriter).write_result(result, file=file, options={})
+    elif output == "tsv":
+        WriteTSV(ResultWriter).write_result(result, file=file, options={})
+    elif output == "json":
+        WriteJSON(ResultWriter).write_result(result, file=file, options={})
+    elif output == "txt":
+        WriteTXT(ResultWriter).write_result(result, file=file, options={})
+    else:
+        return 'Please select an output method!'
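For orientation, a minimal sketch of how the new engine module's `transcribe()` entry point is driven. The direct import and the `sample.wav` filename are illustrative assumptions; in the service this path is reached only through the `/asr` endpoint, and diarization additionally requires `HF_TOKEN` to be set so that `diarize_model` exists:

```python
# Minimal sketch, not part of the PR: exercising the whisperx engine directly.
# Assumes ASR_ENGINE=whisperx and HF_TOKEN set in the environment.
import whisperx

from app.mbain_whisperx.core import transcribe

audio = whisperx.load_audio("sample.wav")  # float32 mono waveform at 16 kHz

srt_file = transcribe(
    audio,
    task="transcribe",
    language=None,            # None lets whisper detect the language
    initial_prompt=None,
    word_timestamps=False,
    options={"diarize": True, "min_speakers": 1, "max_speakers": 2},
    output="srt",
)
print(srt_file.read())        # transcribe() returns a rewound StringIO
```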
diff --git a/app/openai_whisper/core.py b/app/openai_whisper/core.py
index df3af05..9e1bdea 100644
--- a/app/openai_whisper/core.py
+++ b/app/openai_whisper/core.py
@@ -23,6 +23,7 @@ def transcribe(
     language: Union[str, None],
     initial_prompt: Union[str, None],
     word_timestamps: Union[bool, None],
+    options: Union[dict, None],
     output
 ):
     options_dict = {"task": task}
diff --git a/app/webservice.py b/app/webservice.py
index c3846d3..d919a4a 100644
--- a/app/webservice.py
+++ b/app/webservice.py
@@ -12,8 +12,12 @@ from whisper import tokenizer
 
 ASR_ENGINE = os.getenv("ASR_ENGINE", "openai_whisper")
+HF_TOKEN = os.getenv("HF_TOKEN", "")
+
 if ASR_ENGINE == "faster_whisper":
     from .faster_whisper.core import transcribe, language_detection
+elif ASR_ENGINE == "whisperx":
+    from .mbain_whisperx.core import transcribe, language_detection
 else:
     from .openai_whisper.core import transcribe, language_detection
 
@@ -59,16 +63,44 @@ async def index():
 
 @app.post("/asr", tags=["Endpoints"])
-async def asr(
-    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
-    language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
-    initial_prompt: Union[str, None] = Query(default=None),
-    audio_file: UploadFile = File(...),
-    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
-    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
-    word_timestamps: bool = Query(default=False, description="Word level timestamps")
+def asr(
+    task: Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
+    language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
+    initial_prompt: Union[str, None] = Query(default=None),
+    audio_file: UploadFile = File(...),
+    encode: bool = Query(default=True, description="Encode audio first through ffmpeg"),
+    output: Union[str, None] = Query(default="txt", enum=["txt", "vtt", "srt", "tsv", "json"]),
+    word_timestamps: bool = Query(
+        default=False,
+        description="Word level timestamps",
+        include_in_schema=(ASR_ENGINE == "faster_whisper"),
+    ),
+    diarize: bool = Query(
+        default=False,
+        description="Diarize the input",
+        include_in_schema=(ASR_ENGINE == "whisperx" and HF_TOKEN != ""),
+    ),
+    min_speakers: Union[int, None] = Query(
+        default=None,
+        description="Min speakers in this file",
+        include_in_schema=(ASR_ENGINE == "whisperx"),
+    ),
+    max_speakers: Union[int, None] = Query(
+        default=None,
+        description="Max speakers in this file",
+        include_in_schema=(ASR_ENGINE == "whisperx"),
+    ),
 ):
-    result = transcribe(load_audio(audio_file.file, encode), task, language, initial_prompt, word_timestamps, output)
+    result = transcribe(
+        load_audio(audio_file.file, encode),
+        task,
+        language,
+        initial_prompt,
+        word_timestamps,
+        {
+            "diarize": diarize,
+            "min_speakers": min_speakers,
+            "max_speakers": max_speakers
+        },
+        output)
+
     return StreamingResponse(
         result,
         media_type="text/plain",
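A hedged client-side sketch of the extended endpoint. The URL, port, and file name are assumptions; the `diarize`, `min_speakers`, and `max_speakers` query parameters are the ones added above and are only honored when `ASR_ENGINE=whisperx`:

```python
# Minimal client sketch against a locally running container (not part of the PR).
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:9000/asr",
        params={
            "task": "transcribe",
            "output": "srt",
            "diarize": True,
            "min_speakers": 1,
            "max_speakers": 2,
        },
        files={"audio_file": f},
    )
resp.raise_for_status()
print(resp.text)
```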
diff --git a/docs/index.md b/docs/index.md
index 6dff661..38a9682 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -5,6 +5,7 @@ Current release (v1.2.0) supports following whisper models:
 
 - [openai/whisper](https://github.com/openai/whisper)@[v20230918](https://github.com/openai/whisper/releases/tag/v20230918)
 - [guillaumekln/faster-whisper](https://github.com/guillaumekln/faster-whisper)@[0.9.0](https://github.com/guillaumekln/faster-whisper/releases/tag/v0.9.0)
+- [whisperX](https://github.com/m-bain/whisperX)@[v3.1.1](https://github.com/m-bain/whisperX/releases/tag/v3.1.1)
 
 ## Quick Usage