From 2c809ae5b72a0ef7edb640614a87b28e8b5c6c4b Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Tue, 7 Nov 2023 14:28:16 -0800 Subject: [PATCH] add lock for tts main --- advanced/tts/tts_main.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/advanced/tts/tts_main.py b/advanced/tts/tts_main.py index 4142143..67a3f43 100644 --- a/advanced/tts/tts_main.py +++ b/advanced/tts/tts_main.py @@ -1,6 +1,7 @@ from io import BytesIO import os -from typing import List, Optional, Union +from threading import Lock +from typing import List, Optional, Union, Dict from loguru import logger import torch @@ -38,7 +39,7 @@ class Speaker(Photon): MODEL_NAME = "tts_models/en/vctk/vits" # Or, you can choose some other models - # MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v1" + #MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v1" # If you want to load multiple models at the same time, you can put it here # as a comma-separated string. For example: @@ -53,21 +54,23 @@ def init(self): """ from TTS.api import TTS - self._models: dict[Union[str, None], TTS] = {} + self._models: Dict[Union[str, None], TTS] = {} + self._model_lock: Dict[Union[str, None], Lock] = {} self.MODEL_NAME = os.environ.get("MODEL_NAME", self.MODEL_NAME).strip() - self.PRELOAD_MODELS = ( - os.environ.get("PRELOAD_MODELS", self.PRELOAD_MODELS).strip().split(",") - ) + self.PRELOAD_MODELS = [ + m for m in os.environ.get("PRELOAD_MODELS", self.PRELOAD_MODELS).strip().split(",") if m + ] if self.MODEL_NAME not in self.PRELOAD_MODELS: self.PRELOAD_MODELS.append(self.MODEL_NAME) logger.info("Loading the model...") for model_name in self.PRELOAD_MODELS: self._models[model_name] = self._load_model(model_name) - + self._model_lock[model_name] = Lock() self._models[None] = self._models[self.MODEL_NAME] + self._model_lock[None] = self._model_lock[self.MODEL_NAME] logger.debug("Model loaded.") def _load_model(self, model_name: str): @@ -122,12 +125,14 @@ def _tts( logger.info( f"Synthesizing '{text}' with language '{language}' and speaker '{speaker}'" ) - wav = self._models[model].tts( - text=text, - language=language, # type: ignore - speaker=speaker, # type: ignore - speaker_wav=speaker_wav, - ) + # Many of the models might not be python thread safe, so we lock it. + with self._model_lock[model]: + wav = self._models[model].tts( + text=text, + language=language, # type: ignore + speaker=speaker, # type: ignore + speaker_wav=speaker_wav, + ) return wav ########################################################################## @@ -259,7 +264,7 @@ def tts( except Exception as e: raise HTTPException( status_code=500, - detail="Failed to synthesize speech.", + detail=f"Failed to synthesize speech. Details: {e}", ) from e