Merge pull request #15 from leptonai/yqdemo
feat(examples): add whisperx example
Showing 4 changed files with 375 additions and 0 deletions.
# WhisperX

This example demonstrates how to run the WhisperX model on Lepton.

[WhisperX](https://github.com/m-bain/whisperX) builds on the openai whisper model and adds batched inference, word-level timestamps via forced alignment, and speaker diarization. It chunks audio data into segments and then performs batched inference to gain speedup.
## Note on custom environment

As with the Whisper JAX example, if you are running locally, we recommend using a custom environment such as `conda` or `virtualenv`.

Different AI models usually require specific dependencies that sometimes conflict with each other. This is particularly true in the whisper case - from `requirements.txt`, you may notice quite a few specific version requirements.

This is where having a separate service like Lepton becomes super useful: we can create a python environment (using e.g. conda or virtualenv), install the required dependencies, and run the photon as a web service; then, from the regular python environment, we simply call the web service as if it were a regular python function. Compared to the obvious alternatives:
- unlike a single python environment, we don't need to resolve version conflicts between different algorithms;
- unlike packing everything into a separate opaque container image, we stay much more lightweight: only a python environment and its dependencies are needed.
## Running with a custom environment

We recommend using conda or virtualenv to create a whisperx-specific environment. For example, with conda:

```shell
# pick a python version of your choice
conda create -n whisperx python=3.10
conda activate whisperx
```
After that, install lepton [per the installation instructions](https://www.lepton.ai/docs/overview/quickstart#1-installation), and install the required dependencies of this demo via:
```shell
pip install -r requirements.txt
```
After this, you can launch whisperx like:
```shell
# Set your huggingface token. This is required to obtain the respective models.
export HUGGING_FACE_HUB_TOKEN="replace-with-your-own-token"
python main.py
```
It will download the parameters and start the server. After that, use the regular python client to access the model:
```python
from leptonai.client import Client, local
c = Client(local())
```
and invoke transcription or translation as follows:
```python
>>> c.run(filename="assets/thequickbrownfox.wav")
[{'start': 0.028,
  'end': 2.06,
  'text': ' A quick brown fox jumps over the lazy dog.',
  'words': [{'word': 'A', 'start': 0.028, 'end': 0.068, 'score': 0.5},
            {'word': 'quick', 'start': 0.109, 'end': 0.31, 'score': 0.995},
            {'word': 'brown',
             'start': 0.35,
             'end': 0.571,
             'score': 0.849,
             'speaker': 'SPEAKER_00'},
            {'word': 'fox',
             'start': 0.612,
             'end': 0.853,
             'score': 0.897,
             'speaker': 'SPEAKER_00'},
            {'word': 'jumps',
             'start': 0.893,
             'end': 1.175,
             'score': 0.867,
             'speaker': 'SPEAKER_00'},
            {'word': 'over',
             'start': 1.255,
             'end': 1.416,
             'score': 0.648,
             'speaker': 'SPEAKER_00'},
            {'word': 'the',
             'start': 1.456,
             'end': 1.517,
             'score': 0.998,
             'speaker': 'SPEAKER_00'},
            {'word': 'lazy',
             'start': 1.557,
             'end': 1.839,
             'score': 0.922,
             'speaker': 'SPEAKER_00'},
            {'word': 'dog.',
             'start': 1.859,
             'end': 2.06,
             'score': 0.998,
             'speaker': 'SPEAKER_00'}],
  'speaker': 'SPEAKER_00'}]
```
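The `/run` handler in `main.py` also accepts a few optional parameters: `batch_size` for the batched inference, and `min_speakers`/`max_speakers` as hints for the diarization step. The values below are just for illustration:

```python
# Optional parameters of the /run handler (see main.py). Larger batch
# sizes trade GPU memory for throughput; the speaker counts are hints
# passed on to the diarization pipeline.
result = c.run(
    filename="assets/thequickbrownfox.wav",
    batch_size=8,
    min_speakers=1,
    max_speakers=2,
)
```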
## Running with Lepton

The above example runs on the local machine. If your machine does not have a public-facing IP, or more commonly, if you want a stable server environment to host your model, then running on the Lepton cloud platform is the best option. To run it on Lepton, you can simply create a photon and push it to the cloud:

```shell
lep login
lep photon create -n whisperx -m main.py
lep photon push -n whisperx
# An A10 machine is usually big enough to run the large-v2 model.
lep photon run -n whisperx --resource-shape gpu.a10
```

Note that, as in the local run, the photon reads `HUGGING_FACE_HUB_TOKEN` at startup, so make sure this environment variable is also available to the cloud deployment.
After that, you can use `lep deployment status` to obtain the public address of the photon, and connect to it with the client:
```shell
>> lep deployment status -n whisperx
Created at: 2023-08-16 11:08:56
Photon ID:  whisperx-bsip0d8q
State:      Running
Endpoint:   https://latest-whisperx.cloud.lepton.ai
Is Public:  No
Replicas List:
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┓
┃ replica id                ┃ status ┃ message ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━┩
│ whisperx-5ddc79f555-l59cj │ Ready  │ (empty) │
└───────────────────────────┴────────┴─────────┘
1 out of 1 replicas ready.
```
To access the model, we can create a client similar to the local case, simply replacing `local()` with the workspace name, deployment name, and token. Also, since we are now running remotely, we need to upload the audio files. This is done by calling the `run_upload` path:
```python
>>> from leptonai.client import Client
>>> from leptonai.photon import FileParam
>>> c = Client("YOUR_WORKSPACE_NAME", "whisperx", token="YOUR_TOKEN")
>>> c.run_upload(upload_file=FileParam(open("assets/thequickbrownfox.mp3", "rb")))
[{'start': 0.028,
  'end': 2.06,
  'text': ' A quick brown fox jumps over the lazy dog.',
  'words': [{'word': 'A', 'start': 0.028, 'end': 0.068, 'score': 0.5},
            {'word': 'quick', 'start': 0.109, 'end': 0.31, 'score': 0.995},
            {'word': 'brown',
             'start': 0.35,
             'end': 0.571,
             'score': 0.849,
             'speaker': 'SPEAKER_00'},
            {'word': 'fox',
             'start': 0.612,
             'end': 0.853,
             'score': 0.897,
             'speaker': 'SPEAKER_00'},
            {'word': 'jumps',
             'start': 0.893,
             'end': 1.175,
             'score': 0.867,
             'speaker': 'SPEAKER_00'},
            {'word': 'over',
             'start': 1.255,
             'end': 1.416,
             'score': 0.648,
             'speaker': 'SPEAKER_00'},
            {'word': 'the',
             'start': 1.456,
             'end': 1.517,
             'score': 0.998,
             'speaker': 'SPEAKER_00'},
            {'word': 'lazy',
             'start': 1.557,
             'end': 1.839,
             'score': 0.922,
             'speaker': 'SPEAKER_00'},
            {'word': 'dog.',
             'start': 1.859,
             'end': 2.06,
             'score': 0.998,
             'speaker': 'SPEAKER_00'}],
  'speaker': 'SPEAKER_00'}]
```
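The word-level `speaker` labels make the result easy to post-process. As a small sketch (assuming the returned list is stored in a variable `segments`), here is one way to assemble a per-speaker transcript; note that low-confidence words may come back without a `speaker` key, as in the output above:

```python
from collections import defaultdict

# `segments` is the list returned by c.run or c.run_upload above.
by_speaker = defaultdict(list)
for segment in segments:
    for word in segment["words"]:
        # Some words (e.g. 'A' and 'quick' above) carry no speaker label.
        speaker = word.get("speaker", "UNKNOWN")
        by_speaker[speaker].append(word["word"])

for speaker, words in by_speaker.items():
    print(f"{speaker}: {' '.join(words)}")
```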
Unlike local deployment, running on the Lepton cloud platform comes with a series of advantages, especially in the whisperx case:
- You do not need to worry about reproducible software environments: the photon is guaranteed to run in the same environment in which it was created.
- Scaling is easier: you can simply increase the number of replicas if you need more capacity.
- Automatic fault tolerance: if the photon crashes, it will be automatically restarted.
Happy building!
Binary file not shown.
main.py:
import os
import sys
import tempfile
from typing import List, Dict, Optional

import numpy as np

import whisperx

from loguru import logger

from leptonai.photon import Photon, FileParam, HTTPException


class WhisperXPhoton(Photon):
    """
    A WhisperX photon that serves the [WhisperX](https://github.com/m-bain/whisperX) model.
    The photon exposes two endpoints, `/run` and `/run_upload`, that deal with files/urls
    and uploaded contents respectively. See the docs of each for details.
    """

    requirement_dependency = [
        "leptonai",
        "torch",
        "torchaudio",
        "git+https://github.com/m-bain/whisperx.git",
    ]

    # ffmpeg is needed by whisperx.load_audio to decode the audio files.
    system_dependencies = ["ffmpeg"]

    def init(self):
        logger.info("Initializing WhisperXPhoton")
        self.hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
        if not self.hf_token:
            logger.error(
                "Please set the environment variable HUGGING_FACE_HUB_TOKEN."
            )
            sys.exit(1)
        self.device = "cuda"
        # float16 inference requires a CUDA-capable GPU.
        compute_type = "float16"
        self._model = whisperx.load_model(
            "large-v2", self.device, compute_type=compute_type
        )

        # Load the alignment model for the default language; _run_audio
        # reloads it if the detected language differs.
        self._language_code = "en"
        self._model_a, self._metadata = whisperx.load_align_model(
            language_code=self._language_code, device=self.device
        )
        self._diarize_model = whisperx.DiarizationPipeline(
            use_auth_token=self.hf_token, device=self.device
        )

    def _run_audio(
        self,
        audio: np.ndarray,
        batch_size: int = 4,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ) -> List[Dict]:
        """
        The main function that is called by the handlers.
        """
        # 1. Transcribe with the batched whisper model.
        result = self._model.transcribe(audio, batch_size=batch_size)

        # Reload the alignment model if the detected language differs from
        # the currently loaded one.
        if self._language_code != result["language"]:
            self._model_a, self._metadata = whisperx.load_align_model(
                language_code=result["language"], device=self.device
            )
            self._language_code = result["language"]

        # 2. Align whisper output to obtain word-level timestamps.
        result = whisperx.align(
            result["segments"],
            self._model_a,
            self._metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

        # 3. Diarize, passing the min/max speaker hints if provided.
        diarize_segments = self._diarize_model(
            audio, min_speakers=min_speakers, max_speakers=max_speakers
        )
        logger.debug(diarize_segments)

        # Segments are now assigned speaker IDs.
        result = whisperx.assign_word_speakers(diarize_segments, result)
        return result["segments"]

    @Photon.handler(
        example={
            "filename": (
                "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac"
            ),
        }
    )
    def run(
        self,
        filename: str,
        batch_size: int = 4,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ) -> List[Dict]:
        """
        Runs transcription, alignment, and diarization for the input.
        - Inputs:
            - filename: a url pointing to the audio file, or a local file path
              if running locally.
            - batch_size (optional): the batch size for whisperx inference.
            - min_speakers (optional): a hint for the minimum number of
              speakers for diarization.
            - max_speakers (optional): a hint for the maximum number of
              speakers for diarization.
        - Returns:
            - result: a list of dictionaries, each containing one classified
              segment. Each segment is a dictionary with the following keys:
              `start` and `end` specifying the start and end time of the
              segment in seconds, `text` as the recognized text, and `words`
              containing the segmented words and corresponding speaker IDs.
        - 404: if the file cannot be loaded.
        - 500: if an internal error occurs.
        """
        try:
            audio = whisperx.load_audio(filename)
        except Exception as e:
            raise HTTPException(
                status_code=404,
                detail=(
                    f"Cannot load audio at {filename}. Detailed error message: {str(e)}"
                ),
            )
        return self._run_audio(audio, batch_size, min_speakers, max_speakers)

    @Photon.handler(
        example={
            "upload_file": (
                "(please use python) FileParam(open('path/to/your/file.wav', 'rb'))"
            ),
        }
    )
    def run_upload(
        self,
        upload_file: FileParam,
        batch_size: int = 4,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
    ) -> List[Dict]:
        """
        Runs transcription, alignment, and diarization for the input.
        Everything is the same as the `/run` path, except that the input is
        uploaded as a file. If you are using the lepton python client, you can
        do so via:
            from leptonai.photon import FileParam
            from leptonai.client import Client
            client = Client(PUT_YOUR_SERVER_INFO_HERE)
            client.run_upload(upload_file=FileParam(open("path/to/your/file.wav", "rb")))
        For more details, refer to `/run`.
        """
        logger.info(f"upload_file: {upload_file}")
        # WhisperX currently only reads contents from a file, so we write the
        # uploaded bytes to a temporary file first.
        tmpfile = tempfile.NamedTemporaryFile()
        with open(tmpfile.name, "wb") as f:
            f.write(upload_file.file.read())
            f.flush()
        logger.info(f"tmpfile: {tmpfile.name}")
        try:
            audio = whisperx.load_audio(tmpfile.name)
        except Exception as e:
            logger.info(f"Encountered an error; returning 500. Detailed: {e}")
            raise HTTPException(
                status_code=500,
                detail=(
                    "Cannot load audio with uploaded content. Detailed error"
                    f" message: {str(e)}"
                ),
            )
        logger.info("Started running WhisperX")
        ret = self._run_audio(audio, batch_size, min_speakers, max_speakers)
        # Remove the temporary file.
        tmpfile.close()
        return ret


if __name__ == "__main__":
    p = WhisperXPhoton()
    p.launch()
requirements.txt:
leptonai
torch
torchaudio
git+https://github.com/m-bain/whisperx.git