Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add lock for tts main #56

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions advanced/tts/tts_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from io import BytesIO
import os
from typing import List, Optional, Union
from threading import Lock
from typing import List, Optional, Union, Dict

from loguru import logger
import torch
Expand Down Expand Up @@ -38,7 +39,7 @@ class Speaker(Photon):

MODEL_NAME = "tts_models/en/vctk/vits"
# Or, you can choose some other models
# MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v1"
#MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v1"

# If you want to load multiple models at the same time, you can list them here
# as a comma-separated string. For example:
Expand All @@ -53,21 +54,23 @@ def init(self):
"""
from TTS.api import TTS

self._models: dict[Union[str, None], TTS] = {}
self._models: Dict[Union[str, None], TTS] = {}
self._model_lock: Dict[Union[str, None], Lock] = {}

self.MODEL_NAME = os.environ.get("MODEL_NAME", self.MODEL_NAME).strip()

self.PRELOAD_MODELS = (
os.environ.get("PRELOAD_MODELS", self.PRELOAD_MODELS).strip().split(",")
)
self.PRELOAD_MODELS = [
m for m in os.environ.get("PRELOAD_MODELS", self.PRELOAD_MODELS).strip().split(",") if m
]
if self.MODEL_NAME not in self.PRELOAD_MODELS:
self.PRELOAD_MODELS.append(self.MODEL_NAME)

logger.info("Loading the model...")
for model_name in self.PRELOAD_MODELS:
self._models[model_name] = self._load_model(model_name)

self._model_lock[model_name] = Lock()
self._models[None] = self._models[self.MODEL_NAME]
self._model_lock[None] = self._model_lock[self.MODEL_NAME]
logger.debug("Model loaded.")

def _load_model(self, model_name: str):
Expand Down Expand Up @@ -122,12 +125,14 @@ def _tts(
logger.info(
f"Synthesizing '{text}' with language '{language}' and speaker '{speaker}'"
)
wav = self._models[model].tts(
text=text,
language=language, # type: ignore
speaker=speaker, # type: ignore
speaker_wav=speaker_wav,
)
# Many of the models might not be thread-safe in Python, so we serialize calls with a per-model lock.
with self._model_lock[model]:
wav = self._models[model].tts(
text=text,
language=language, # type: ignore
speaker=speaker, # type: ignore
speaker_wav=speaker_wav,
)
return wav

##########################################################################
Expand Down Expand Up @@ -259,7 +264,7 @@ def tts(
except Exception as e:
raise HTTPException(
status_code=500,
detail="Failed to synthesize speech.",
detail=f"Failed to synthesize speech. Details: {e}",
) from e


Expand Down
Loading