Skip to content

Commit

Permalink
update whisperx, add support of setting vad_method - silero or pyannote
Browse files Browse the repository at this point in the history
  • Loading branch information
Nyralei committed Oct 25, 2024
1 parent 6146651 commit 5d48a36
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ __pycache__/

# Custom
models
compose-dev.yaml
scripts
whisperx

# C extensions
*.so
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ fastapi==0.114.1
uvicorn==0.30.6
pydantic==2.9.1
aiohttp==3.10.5
whisperx @ git+https://github.com/Nyralei/whisperX.git@25c3fc61f795bf3fd40663098633fdba893eab0e
whisperx @ git+https://github.com/Nyralei/whisperX.git@a44b97fc42d9f2a102cb70e5abcad65d9fceb85c
ctranslate2==4.4.0
numpy==1.26.4
tqdm==4.66.5
Expand Down
8 changes: 8 additions & 0 deletions src/whisperx_api_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ class Device(str, Enum):
CUDA = "cuda"
AUTO = "auto"

class VadMethod(str, Enum):
SILERO = "silero"
PYANNOTE = "pyannote"

class WhisperConfig(BaseModel):
"""See https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/transcribe.py#L599."""

Expand All @@ -148,6 +152,10 @@ class WhisperConfig(BaseModel):
compute_type: Quantization = Field(default=Quantization.DEFAULT)
cpu_threads: int = Field(default=0)
num_workers: int = Field(default=1)
vad_method: VadMethod = Field(default=VadMethod.PYANNOTE)
vad_model: str = Field(default=None)
vad_options: dict = Field(default=None)


class Config(BaseSettings):
"""
Expand Down
3 changes: 3 additions & 0 deletions src/whisperx_api_server/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ async def transcribe(
compute_type=whispermodel.compute_type,
language=language,
asr_options=asr_options,
vad_model=config.whisper.vad_model,
vad_method=config.whisper.vad_method,
vad_options=config.whisper.vad_options,
model=whispermodel,
)
logger.info(f"Loading model took {time.time() - model_loading_start:.2f} seconds")
Expand Down

0 comments on commit 5d48a36

Please sign in to comment.