Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
chocolatedesue committed Jan 10, 2023
1 parent 01d4e55 commit b21f98a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
5 changes: 1 addition & 4 deletions app/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

class HParams():
def __init__(self, **kwargs):
for k, v in kwargs.items():
Expand Down Expand Up @@ -28,7 +29,3 @@ def __contains__(self, key):

def __repr__(self):
return self.__dict__.__repr__()


MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJlRFc1djM5Q1MzOUhWRGc/root/content"
CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content"
21 changes: 15 additions & 6 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path

from loguru import logger
from app import CONFIG_URL, MODEL_URL
# from app import CONFIG_URL, MODEL_URL
from app.util import get_hparams_from_file, get_paths, time_it
import requests
from tqdm.auto import tqdm
Expand All @@ -12,19 +12,29 @@
import threading


MODEL_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJmckZWcGdCR0xxLWJmU28/root/content"
CONFIG_URL = r"https://api.onedrive.com/v1.0/shares/u!aHR0cHM6Ly8xZHJ2Lm1zL3UvcyFBdG53cTVRejJnLTJhNEJ3enhhUHpqNE5EZWc/root/content"



class Config:
hps: dict = None
pattern: Pattern = None
# symbol_to_id:dict = None
speaker_choices: list = None
ort_sess: ort.InferenceSession = None
model_is_ok: bool = False

@classmethod
def init(cls):

# logger.add(
# "vits_infer.log", rotation="10 MB", encoding="utf-8", enqueue=True, retention="30 days"
# )

brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"]
cls.pattern = re.compile('|'.join(map(re.escape, brackets)))

dir_path = Path(__file__).parent.absolute() / ".model"
dir_path.mkdir(
parents=True, exist_ok=True
Expand All @@ -50,10 +60,9 @@ def init(cls):
cls.setup_config(str(config_path))
cls.setup_model(str(model_path))

brackets = ['(', '[', '『', '「', '【', ")", "】", "]", "』", "」", ")"]
cls.pattern = re.compile('|'.join(map(re.escape, brackets)))

@classmethod
@logger.catch
@time_it
def setup_model(cls, model_path: str):
import numpy as np
cls.ort_sess = ort.InferenceSession(model_path)
Expand Down Expand Up @@ -82,6 +91,8 @@ def setup_model(cls, model_path: str):
}
cls.ort_sess.run(None, ort_inputs)

cls.model_is_ok = True

logger.info(
f"model init done with model path {model_path}"
)
Expand All @@ -97,8 +108,6 @@ def setup_config(cls, config_path: str):
)

@classmethod
@time_it
@logger.catch
def pdownload(cls, url: str, save_path: str, chunk_size: int = 8192):
# copy from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_requests.py
file_size = int(requests.head(url).headers["Content-Length"])
Expand Down
3 changes: 2 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def tts_fn(text, speaker_id, speed=1.0):
if len(text) > 300:
return "Error: Text is too long, please down it to 300 characters", None

if Config.ort_sess is None:
if not Config.model_is_ok:
return "Error: model not loaded, please wait for a while or look the log", None

seq = text_to_seq(text)
Expand Down Expand Up @@ -78,6 +78,7 @@ def set_gradio_view():
tts_submit.click(tts_fn, inputs=inputs, outputs=outputs)

app.queue(concurrency_count=3)
gr.close_all()
app.launch(server_name='0.0.0.0', show_api=False,
share=False, server_port=7860)

Expand Down

0 comments on commit b21f98a

Please sign in to comment.