diff --git a/advanced/whisperx/main.py b/advanced/whisperx/main.py
index 8ea87ac..7c01c7b 100644
--- a/advanced/whisperx/main.py
+++ b/advanced/whisperx/main.py
@@ -1,16 +1,13 @@
 import base64
-import json
 import os
 import sys
 import time
-import tempfile
-from threading import Lock
-from typing import Dict, Optional, Union
-import uuid
+from typing import List, Dict, Optional, Union
+from threading import Lock

 import numpy as np
-from leptonai.photon import Photon, FileParam, HTTPException
+from leptonai.photon import Photon, FileParam, HTTPException, get_file_content
 from loguru import logger

 # Note: instead of importing whisperx in the main file, we import it in the functions that
@@ -19,179 +16,168 @@
 # import whisperx


-class WhisperXBackground(Photon):
+class WhisperX(Photon):
     """
     A WhisperX photon that serves the [WhisperX](https://github.com/m-bain/whisperX) model.

-    The photon exposes two endpoints: `/run` and `/run_upload` that deals with files/urls
-    and uploaded contents respectively. See the docs of each for details.
-
-    Different from the main WhisperX photon, this photon starts background tassks in the
-    background, so it handles requests more efficiently. For example, for the main photon,
-    if you run a prediction, it will block the server until the prediction is done, giving
-    a pretty bad user experience. This photon, on the other hand, will return immediately
-    with a task id. The user can then use the task id to query the status of the task, and
-    get the result when it is done.
+    The photon exposes a single endpoint "/run" that takes an audio file as input and returns
+    the transcription, alignment, and diarization results.
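+
+    A minimal client-side sketch (illustrative only; the workspace name, deployment name,
+    token, and audio URL below are placeholders):
+
+        from leptonai.client import Client
+
+        c = Client("my-workspace", "whisperx", token="MY_TOKEN")
+        segments = c.run(input="https://example.com/audio.wav", transcribe_only=True)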
     """

+    # Note: openai-whisper implicitly requires triton 2.1.0, which in turn might be in conflict
+    # with non-version-pinned torch and torchaudio. As a result, we will pin all three versions
+    # here.
     requirement_dependency = [
-        "leptonai",
         "torch",
         "torchaudio",
-        "git+https://github.com/m-bain/whisperx.git",
+        "triton",
+        "openai-whisper",
+        "leptonai",
+        "pyannote.audio",
+        "git+https://github.com/m-bain/whisperx.git@e9c507ce5dea0f93318746411c03fed0926b70be",
     ]

     system_dependencies = ["ffmpeg"]

-    # Parameters for the photon.
-    # The photon will need to have a storage folder
-    OUTPUT_ROOT = os.environ.get("WHISPERX_OUTPUT_ROOT", "/tmp/whisperx")
-    INPUT_FILE_EXTENSION = ".npy"
-    OUTPUT_FILE_EXTENSION = ".json"
-    OUTPUT_MAXIMUM_AGE = 60 * 60 * 24  # 1 day
-    CLEANUP_INTERVAL = 60 * 60  # 1 hour
-    LAST_CLEANUP_TIME = time.time()
+    deployment_template = {
+        "resource_shape": "gpu.a10",
+        "env": {
+            "WHISPER_MODEL": "large-v3",
+            # Maximum audio length that the API allows. By default, we use
+            # 10 minutes. If you are deploying this yourself, you can change
+            # it to be longer.
+            "MAX_LENGTH_IN_SECONDS": "600",
+        },
+        "secret": [
+            "HUGGING_FACE_HUB_TOKEN",
+        ]
+    }
+
+    handler_max_concurrency = 8
+
+    SUPPORTED_LANGUAGES = {"en", "fr", "de", "es", "it", "ja", "zh", "nl", "uk", "pt"}
+    # The main language for the model
+    MAIN_LANGUAGE = "en"
+    # Batch size benchmarked to give the best balance on an A10
+    DEFAULT_BATCH_SIZE = 16
+
+    # Because each alignment language takes a nontrivial amount of memory,
+    # we only keep the languages that we find are commonly requested, and load other
+    # models on demand. You can change this to host more alignment models in a
+    # warm state at the cost of more memory.
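+    # For example, to also keep German and Japanese warm, you could use
+    # ALIGNMENT_LANGUAGE_TO_KEEP = {"en", "zh", "es", "de", "ja"}.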
+    ALIGNMENT_LANGUAGE_TO_KEEP = {"en", "zh", "es"}

     def init(self):
+        import torch
         import whisperx
+        from whisperx.asr import FasterWhisperPipeline

-        logger.info("Initializing WhisperXPhoton")
+        logger.info("Initializing WhisperX")

-        # 0. Create output root, and launch a thread to clean up old files
-        os.makedirs(self.OUTPUT_ROOT, exist_ok=True)
+        self.USE_FASTER_WHISPER = True
+        self.WHISPER_MODEL = os.environ["WHISPER_MODEL"]
+        self.MAX_LENGTH_IN_SECONDS = int(os.environ["MAX_LENGTH_IN_SECONDS"])

         # 1. Load whisper model
-        self.hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
+        self.hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
         if not self.hf_token:
             logger.error("Please set the environment variable HUGGING_FACE_HUB_TOKEN.")
             sys.exit(1)
-        self.device = "cuda"
-        compute_type = "float16"
-        self._model = whisperx.load_model(
-            "large-v2", self.device, compute_type=compute_type
+        if torch.cuda.is_available():
+            self.device = "cuda"
+            compute_type = "float16"
+        else:
+            self.device = "cpu"
+            compute_type = "float32"
+
+        # We keep a main model pinned to MAIN_LANGUAGE so we don't need to always reload
+        # tokenizers. We also keep a multilingual model that can handle all languages.
+        self._main_model = whisperx.load_model(
+            self.WHISPER_MODEL, self.device, compute_type=compute_type, language=self.MAIN_LANGUAGE,
+        )
+        self._multilingual_model = FasterWhisperPipeline(
+            model=self._main_model.model,
+            vad=self._main_model.vad_model,
+            options=self._main_model.options,
+            tokenizer=None,
+            language=None,
+            suppress_numerals=self._main_model.suppress_numerals,
+            vad_params=self._main_model._vad_params,
         )
+        # For the main model, inference is not thread safe (because of some underlying cuda memory
+        # accesses). As a result, whenever we use the transcribe model, we need to lock it.
+        self.transcribe_model_lock = Lock()

-        # 2. load whisper align model
+        # 2. load whisper align model. Alignment models are language specific, so we keep
+        # them in a dictionary, and only preload the languages in ALIGNMENT_LANGUAGE_TO_KEEP.
         self._model_a = {}
         self._metadata = {}
-        self._model_a["en"], self._metadata["en"] = whisperx.load_align_model(
-            language_code="en", device=self.device
-        )
+        for lang in self.ALIGNMENT_LANGUAGE_TO_KEEP:
+            self._model_a[lang], self._metadata[lang] = whisperx.load_align_model(
+                language_code=lang, device=self.device
+            )
+        # Since we don't know whether whisperx's align function is perfectly thread safe,
+        # we will lock it as well.
         self.align_model_lock = Lock()
+
+        # 3. load whisper diarize model. The diarization model is currently thread safe,
+        # but we guard it with a lock anyway.
         self._diarize_model = whisperx.DiarizationPipeline(
-            use_auth_token=self.hf_token, device=self.device
+            model_name="pyannote/speaker-diarization@2.1", use_auth_token=self.hf_token, device=self.device
         )
-
-    def _gen_unique_filename(self) -> str:
-        return str(uuid.uuid4())
-
-    def _regular_clean_up(self):
-        if time.time() - self.LAST_CLEANUP_TIME < self.CLEANUP_INTERVAL:
-            return
-        logger.info(f"Cleaning up {self.OUTPUT_ROOT} regularly")
-        self.LAST_CLEANUP_TIME = time.time()
-        for filename in os.listdir(self.OUTPUT_ROOT):
-            if filename.endswith(self.OUTPUT_FILE_EXTENSION):
-                filepath = os.path.join(self.OUTPUT_ROOT, filename)
-                # Checks if files are older than 1 hour. If so, delete them.
-                if (
-                    os.path.isfile(filepath)
-                    and time.time() - os.path.getmtime(filepath)
-                    > self.OUTPUT_MAXIMUM_AGE
-                ):
-                    os.remove(filepath)
-
-    def _run_whisperx(
-        self,
-        audio: Union[np.ndarray, str],
-        language: Optional[str] = None,
-        min_speakers: Optional[int] = None,
-        max_speakers: Optional[int] = None,
-        transcribe_only: bool = False,
-        task_id: Optional[str] = None,
-    ) -> Optional[Dict]:
-        """
-        The main function that is called by the others.
-        """
+        self._diarize_model_lock = Lock()
+
+    def _transcribe(self, audio: np.ndarray, audio_file, language: Optional[str] = None):
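+        # MAIN_LANGUAGE requests reuse the main model, so its tokenizer does not need to
+        # be reloaded; any other language goes through the multilingual pipeline.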
+        logger.debug("transcribe: acquiring lock")
+        with self.transcribe_model_lock:
+            logger.debug("transcribe: lock acquired")
+            if language == self.MAIN_LANGUAGE:
+                result = self._main_model.transcribe(
+                    audio,
+                    batch_size=self.DEFAULT_BATCH_SIZE,
+                    language=language)
+            else:
+                result = self._multilingual_model.transcribe(
+                    audio,
+                    batch_size=self.DEFAULT_BATCH_SIZE,
+                    language=language)
+        logger.debug("transcribe: lock released")
+        return result
+
+    def _align(self, result, audio):
+        # Run alignment
         import whisperx

-        batch_size = 16
-        start_time = time.time()
-
-        audio_filename = None
-        if isinstance(audio, str):
-            # Load the numpy array from the file
-            audio_filename = audio
-            logger.debug(f"Loading audio from file {audio_filename}.")
-            audio = np.load(audio)
-        logger.debug(f"started processing audio of length {len(audio)}.")
-        if task_id:
-            logger.debug(f"task_id: {task_id}")
-        result = self._model.transcribe(audio, batch_size=batch_size, language=language)
-        if len(result["segments"]) == 0:
-            logger.debug("Empty result from whisperx. Directly return empty.")
-            return []
-
-        if not transcribe_only:
-            with self.align_model_lock:
-                if result["language"] not in self._model_a:
-                    (
-                        self._model_a[result["language"]],
-                        self._metadata[result["language"]],
-                    ) = whisperx.load_align_model(
-                        language_code=result["language"], device=self.device
-                    )
+        logger.debug("Start alignment")
+        if result["language"] in self._model_a:
+            model_a = self._model_a[result["language"]]
+            metadata_a = self._metadata[result["language"]]
+        else:
+            # load model_a and metadata on-demand
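+            # (The on-demand model is not cached, so every request in a language outside
+            # ALIGNMENT_LANGUAGE_TO_KEEP reloads it from scratch; add the language to that
+            # set above if it becomes common.)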
- if input.startswith("data:audio/wav;base64,"): - input = input[22:] - with tempfile.NamedTemporaryFile() as f: - decoded_data = base64.b64decode(input) - f.write(decoded_data) - f.flush() - filename = f.name - audio = whisperx.load_audio(filename) - except Exception: + start_time = time.time() + logger.debug(f"Start processing audio {input}") + audio_file = get_file_content(input, return_file=True) + audio = whisperx.load_audio(audio_file.name) + if audio.size > self.MAX_LENGTH_IN_SECONDS * 16000: raise HTTPException( 400, - "Invalid input. Please check your input, it should be a FileParam" - " (python), a url, or a base64 encoded string.", + f"Audio length {audio.size / 16000} seconds is longer than the maximum" + f" allowed length {self.MAX_LENGTH_IN_SECONDS} seconds.", ) + logger.debug(f"started processing audio of length {len(audio)}.") + # Note: audio_file is basically provided so we can run whisper if needed. Whisper + # uses a different loading mechanism, and as a result it is slightly different from + # the whisperx loaded audio. + result = self._transcribe(audio, audio_file, language=language) + logger.debug(f"Transcription done.") - SAMPLE_RATE = 16000 # The default sample rate that whisperx uses - if len(audio) < SAMPLE_RATE * 60: - # For audio shorter than 1 minute, directly compute and return the result. - ret = self._run_whisperx( - audio, - language, - min_speakers, - max_speakers, - transcribe_only, - ) - if ret is None: - raise HTTPException( - 500, "You hit a programming error - please let us know." - ) - else: - return ret - elif len(audio) > SAMPLE_RATE * 60 * 60: - # For audio longer than 90 minutes, raise an error. - raise HTTPException(400, "Audio longer than 60 minutes is not supported.") - else: - task_id = self._gen_unique_filename() - input_filepath = os.path.join( - self.OUTPUT_ROOT, task_id + self.INPUT_FILE_EXTENSION - ) - np.save(input_filepath, audio) - self.add_background_task( - self._run_whisperx, - input_filepath, - language, - min_speakers, - max_speakers, - transcribe_only, - task_id, + if len(result["segments"]) == 0: + logger.debug("Empty result from whisperx. Directly return empty.") + return [] + + if transcribe_only: + total_time = time.time() - start_time + logger.debug( + f"finished processing audio of len {audio.size}. Total" + f" time: {total_time} ({audio.size / 16000 / total_time} x realtime)" ) - self._regular_clean_up() - return {"task_id": task_id} + return result["segments"] - @Photon.handler - def status(self, task_id: str) -> Dict[str, str]: - """ - Returns the status of the task. It could be "invalid_task_id", "pending", "not_found", or "ok". - """ + # Run alignment and diarization + result = self._align(result, audio) + + # When there is no active diarization, the diarize model throws a KeyError. + # In this case, we simply skip diarization. try: - _ = uuid.UUID(task_id, version=4) - except ValueError: - return {"status": "invalid_task_id"} - input_filepath = os.path.join( - self.OUTPUT_ROOT, task_id + self.INPUT_FILE_EXTENSION - ) - output_filepath = os.path.join( - self.OUTPUT_ROOT, task_id + self.OUTPUT_FILE_EXTENSION - ) - if not os.path.exists(output_filepath): - if os.path.exists(input_filepath): - return {"status": "pending"} - else: - return {"status": "not_found"} + diarize_segments = self._diarize(audio, min_speakers, max_speakers) + except Exception as e: + logger.error(f"Error in diarization: {e}. 
Skipping diarization.") else: - return {"status": "ok"} + result = whisperx.assign_word_speakers(diarize_segments, result) - @Photon.handler - def get_result(self, task_id: str) -> Dict: - """ - Gets the result of the whisper x task. If the task is not finished, it will raise a 404 error. - Use `status(task_id=task_id)` to check if the task is finished. - """ - if self.status(task_id=task_id)["status"] == "ok": - output_filepath = os.path.join( - self.OUTPUT_ROOT, task_id + self.OUTPUT_FILE_EXTENSION - ) - return json.load(open(output_filepath, "r")) - else: - raise HTTPException(status_code=404, detail="result not found") + total_time = time.time() - start_time + logger.debug( + f"finished processing audio of len {audio.size}. Total" + f" time: {total_time} ({audio.size / 16000 / total_time} x realtime)" + ) + return result["segments"] - def queue_length(self) -> int: + @Photon.handler + def model(self) -> str: """ - Returns the current queue length. + Returns the whisper model string. """ - return len([ - f - for f in os.listdir(self.OUTPUT_ROOT) - if f.endswith(self.INPUT_FILE_EXTENSION) - ]) + return self.WHISPER_MODEL if __name__ == "__main__": - p = WhisperXBackground() + p = WhisperX() p.launch() diff --git a/advanced/whisperx/requirements.txt b/advanced/whisperx/requirements.txt index 6a80f2e..1ff81f9 100644 --- a/advanced/whisperx/requirements.txt +++ b/advanced/whisperx/requirements.txt @@ -1,4 +1,5 @@ -leptonai torch torchaudio -git+https://github.com/m-bain/whisperx.git +leptonai +pyannote.audio +git+https://github.com/m-bain/whisperx.git@e9c507ce5dea0f93318746411c03fed0926b70be