diff --git a/app/__init__.py b/app/__init__.py index 20630e9..2c8d25e 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,3 +1,4 @@ + class HParams(): def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -28,7 +29,3 @@ def __contains__(self, key): def __repr__(self): return self.__dict__.__repr__() - - -MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJlRFc1djM5Q1MzOUhWRGc/root/content" -CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content" diff --git a/app/config.py b/app/config.py index 2a0d3e6..e3e73ee 100644 --- a/app/config.py +++ b/app/config.py @@ -2,7 +2,7 @@ from pathlib import Path from loguru import logger -from app import CONFIG_URL, MODEL_URL +# from app import CONFIG_URL, MODEL_URL from app.util import get_hparams_from_file, get_paths, time_it import requests from tqdm.auto import tqdm @@ -12,12 +12,18 @@ import threading +MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJmckZWcGdCR0xxLWJmU28/root/content" +CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content" + + + class Config: hps: dict = None pattern: Pattern = None # symbol_to_id:dict = None speaker_choices: list = None ort_sess: ort.InferenceSession = None + model_is_ok: bool = False @classmethod def init(cls): @@ -25,6 +31,10 @@ def init(cls): # logger.add( # "vits_infer.log", rotation="10 MB", encoding="utf-8", enqueue=True, retention="30 days" # ) + + brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"] + cls.pattern = re.compile('|'.join(map(re.escape, brackets))) + dir_path = Path(__file__).parent.absolute() / ".model" dir_path.mkdir( parents=True, exist_ok=True @@ -50,10 +60,9 @@ def init(cls): cls.setup_config(str(config_path)) cls.setup_model(str(model_path)) - brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"] - cls.pattern = re.compile('|'.join(map(re.escape, brackets))) - @classmethod + @logger.catch + @time_it def setup_model(cls, model_path: str): import numpy as np cls.ort_sess = ort.InferenceSession(model_path) @@ -82,6 +91,8 @@ def setup_model(cls, model_path: str): } cls.ort_sess.run(None, ort_inputs) + cls.model_is_ok = True + logger.info( f"model init done with model path {model_path}" ) @@ -97,8 +108,6 @@ def setup_config(cls, config_path: str): ) @classmethod - @time_it - @logger.catch def pdownload(cls, url: str, save_path: str, chunk_size: int = 8192): # copy from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_requests.py file_size = int(requests.head(url).headers["Content-Length"]) diff --git a/app/main.py b/app/main.py index 78ce179..4180995 100644 --- a/app/main.py +++ b/app/main.py @@ -26,7 +26,7 @@ def tts_fn(text, speaker_id, speed=1.0): if len(text) > 300: return "Error: Text is too long, please down it to 300 characters", None - if Config.ort_sess is None: + if not Config.model_is_ok: return "Error: model not loaded, please wait for a while or look the log", None seq = text_to_seq(text) @@ -78,6 +78,7 @@ def set_gradio_view(): tts_submit.click(tts_fn, inputs=inputs, outputs=outputs) app.queue(concurrency_count=3) + gr.close_all() app.launch(server_name='0.0.0.0', show_api=False, share=False, server_port=7860)