Skip to content

Commit

Permalink
兼容gpt sovits官方整合包0706
Browse files Browse the repository at this point in the history
  • Loading branch information
Ikaros-521 committed Jul 21, 2024
1 parent 213d74f commit cd088d6
Show file tree
Hide file tree
Showing 7 changed files with 301 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,5 @@ AIHubMix: [aihubmix.com](https://aihubmix.com/register?aff=1BMI) ———— O
| 用户信息 | 名人名言 |
|--------|------|
| QQ:750359376 | 笑死,连点开源精神都没有 |
| QQ:378198682 | 【散播谣言】 |

7 changes: 7 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,13 @@
"return_fragment": false,
"fragment_interval": "0.3"
},
"api_0706": {
"refer_wav_path": "F:\\GPT-SoVITS-0304\\output\\slicer_opt\\smoke1.wav",
"text_language": "中文",
"prompt_text": "整整策划了半年了,终于现在有结果了",
"prompt_language": "中文",
"cut_punc": ",。"
},
"webtts": {
"version": "1",
"api_ip_port": "http://127.0.0.1:8080",
Expand Down
7 changes: 7 additions & 0 deletions config.json.bak
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,13 @@
"return_fragment": false,
"fragment_interval": "0.3"
},
"api_0706": {
"refer_wav_path": "F:\\GPT-SoVITS-0304\\output\\slicer_opt\\smoke1.wav",
"text_language": "中文",
"prompt_text": "整整策划了半年了,终于现在有结果了",
"prompt_language": "中文",
"cut_punc": ",。"
},
"webtts": {
"version": "1",
"api_ip_port": "http://127.0.0.1:8080",
Expand Down
229 changes: 229 additions & 0 deletions tests/test_gpt_sovits/api_0706.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import logging, json, aiohttp, os, traceback
import base64
import mimetypes
import websockets
import asyncio

async def gpt_sovits_api(data):


def file_to_data_url(file_path):
# 根据文件扩展名确定 MIME 类型
mime_type, _ = mimetypes.guess_type(file_path)

# 读取文件内容
with open(file_path, "rb") as file:
file_content = file.read()

# 转换为 Base64 编码
base64_encoded_data = base64.b64encode(file_content).decode('utf-8')

# 构造完整的 Data URL
return f"data:{mime_type};base64,{base64_encoded_data}"

async def websocket_client(data_json):
try:
async with websockets.connect(data["ws_ip_port"]) as websocket:
# 设置最大连接时长(例如 30 秒)
return await asyncio.wait_for(websocket_client_logic(websocket, data_json), timeout=30)
except asyncio.TimeoutError:
logging.error("gpt_sovits WebSocket连接超时")
return None

async def websocket_client_logic(websocket, data_json):
async for message in websocket:
logging.debug(f"Received message: {message}")

# 解析收到的消息
data = json.loads(message)
# 检查是否是预期的消息
if "msg" in data:
if data["msg"] == "send_hash":
# 发送响应消息
response = json.dumps({"session_hash":"3obpzfqql7f","fn_index":3})
await websocket.send(response)
logging.debug(f"Sent message: {response}")
elif data["msg"] == "send_data":
# audio_path = "F:\\GPT-SoVITS\\raws\\ikaros\\1.wav"
audio_path = data_json["ref_audio_path"]

# 发送响应消息
response = json.dumps(
{
"session_hash":"3obpzfqql7f",
"fn_index":3,
"data":[
{
"data": file_to_data_url(audio_path),
"name": os.path.basename(audio_path)
},
data_json["prompt_text"],
data_json["prompt_language"],
data_json["content"],
data_json["language"],
data_json["cut"]
]
}
)
await websocket.send(response)
logging.debug(f"Sent message: {response}")
elif data["msg"] == "process_completed":
return data["output"]["data"][0]["name"]

try:
logging.debug(f"data={data}")

if data["type"] == "gradio":
# 调用函数并等待结果
voice_tmp_path = await websocket_client(data)

# if voice_tmp_path:
# new_file_path = self.common.move_file(voice_tmp_path, os.path.join(self.audio_out_path, 'gpt_sovits_' + self.common.get_bj_time(4)), 'gpt_sovits_' + self.common.get_bj_time(4))

new_file_path = 'gpt_sovits_.wav'

return new_file_path
elif data["type"] == "api":
try:
data_json = {
"refer_wav_path": data["ref_audio_path"],
"prompt_text": data["prompt_text"],
"prompt_language": data["prompt_language"],
"text": data["content"],
"text_language": data["language"]
}

async with aiohttp.ClientSession() as session:
async with session.post(data["api_ip_port"], json=data_json, timeout=30) as response:
response = await response.read()

file_name = 'gpt_sovits_.wav'

voice_tmp_path = file_name

# voice_tmp_path = self.common.get_new_audio_path(self.audio_out_path, file_name)

with open(voice_tmp_path, 'wb') as f:
f.write(response)

return voice_tmp_path
except aiohttp.ClientError as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits请求失败: {e}')
except Exception as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits未知错误: {e}')
elif data["type"] == "webtts":
try:
# 使用字典推导式构建 params 字典,只包含非空字符串的值
params = {
key: value
for key, value in data["webtts"].items()
if value != ""
if key != "api_ip_port"
}

# params["speed"] = self.get_random_float(params["speed"])
params["text"] = data["content"]

async with aiohttp.ClientSession() as session:
async with session.get(data["webtts"]["api_ip_port"], params=params, timeout=30) as response:
response = await response.read()

file_name = 'gpt_sovits_.wav'

voice_tmp_path = file_name

# voice_tmp_path = self.common.get_new_audio_path(self.audio_out_path, file_name)

with open(voice_tmp_path, 'wb') as f:
f.write(response)

return voice_tmp_path
except aiohttp.ClientError as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits请求失败: {e}')
except Exception as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits未知错误: {e}')
elif data["type"] == "api_0706":
try:
data_json = {
"refer_wav_path": data["ref_audio_path"],
"prompt_text": data["prompt_text"],
"prompt_language": data["prompt_language"],
"text": data["content"],
"text_language": data["language"]
}

async with aiohttp.ClientSession() as session:
async with session.post(data["api_ip_port"], json=data_json, timeout=30) as response:
response = await response.read()

file_name = 'gpt_sovits_.wav'

voice_tmp_path = file_name

# voice_tmp_path = self.common.get_new_audio_path(self.audio_out_path, file_name)

with open(voice_tmp_path, 'wb') as f:
f.write(response)

return voice_tmp_path
except aiohttp.ClientError as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits请求失败: {e}')
except Exception as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits未知错误: {e}')

except Exception as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits未知错误,请检查您的gpt_sovits推理是否启动/配置是否正确,报错内容: {e}')

return None


async def gpt_sovits_set_model(data):
from urllib.parse import urljoin

if data["type"] == "api":
try:
data_json = {
"gpt_model_path": data["gpt_model_path"],
"sovits_model_path": data["sovits_model_path"]
}

API_URL = urljoin(data["api_ip_port"], '/set_model')

async with aiohttp.ClientSession() as session:
async with session.post(API_URL, json=data_json, timeout=30) as response:
response = await response.read()

print(response)

return response
except aiohttp.ClientError as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits请求失败: {e}')
except Exception as e:
logging.error(traceback.format_exc())
logging.error(f'gpt_sovits未知错误: {e}')


if __name__ == '__main__':
# 配置日志输出格式
logging.basicConfig(
level=logging.DEBUG, # 设置日志级别,可以根据需求调整
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

data = {
"type": "api",
"api_ip_port": "http://127.0.0.1:9880",
"gpt_model_path": "F:\GPT-SoVITS\GPT_weights\ikaros-e15.ckpt",
"sovits_model_path": "F:\GPT-SoVITS\SoVITS_weights\ikaros_e8_s280.pth"
}

asyncio.run(gpt_sovits_set_model(data))
5 changes: 5 additions & 0 deletions utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,9 @@ async def tts_handle(self, message):
else:
message["data"]["api_0322"]["text_lang"] = "中文" # 无法识别出语言代码时的默认值

if message["data"]["api_0706"]["text_language"] == "自动识别":
message["data"]["api_0706"]["text_language"] = "auto"

data = {
"type": message["data"]["type"],
"gradio_ip_port": message["data"]["gradio_ip_port"],
Expand All @@ -1019,6 +1022,7 @@ async def tts_handle(self, message):
"language": language,
"cut": message["data"]["cut"],
"api_0322": message["data"]["api_0322"],
"api_0706": message["data"]["api_0706"],
"webtts": message["data"]["webtts"],
"content": message["content"]
}
Expand Down Expand Up @@ -1956,6 +1960,7 @@ async def audio_synthesis_use_local_config(self, content, audio_synthesis_type="
"language": language,
"cut": self.config.get("gpt_sovits", "cut"),
"api_0322": self.config.get("gpt_sovits", "api_0322"),
"api_0706": self.config.get("gpt_sovits", "api_0706"),
"webtts": self.config.get("gpt_sovits", "webtts"),
"content": content
}
Expand Down
19 changes: 19 additions & 0 deletions utils/audio_handle/my_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,25 @@ async def websocket_client_logic(websocket, data_json):
"fragment_interval":data["api_0322"]["fragment_interval"],
}

return await self.download_audio("gpt_sovits", data["api_ip_port"], self.timeout, "post", None, data_json)
except aiohttp.ClientError as e:
logger.error(traceback.format_exc())
logger.error(f'gpt_sovits请求失败: {e}')
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f'gpt_sovits未知错误: {e}')
elif data["type"] == "api_0706":
try:

data_json = {
"text": data["content"],
"refer_wav_path": data["api_0706"]["refer_wav_path"],
"text_language": data["api_0706"]["text_language"],
"prompt_text": data["api_0706"]["prompt_text"],
"prompt_language": data["api_0706"]["prompt_language"],
"cut_punc": data["api_0706"]["cut_punc"],
}

return await self.download_audio("gpt_sovits", data["api_ip_port"], self.timeout, "post", None, data_json)
except aiohttp.ClientError as e:
logger.error(traceback.format_exc())
Expand Down
35 changes: 33 additions & 2 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,12 @@ def common_textarea_handle(content):
config_data["gpt_sovits"]["api_0322"]["split_bucket"] = switch_gpt_sovits_api_0322_split_bucket.value
config_data["gpt_sovits"]["api_0322"]["return_fragment"] = switch_gpt_sovits_api_0322_return_fragment.value

config_data["gpt_sovits"]["api_0706"]["refer_wav_path"] = input_gpt_sovits_api_0706_refer_wav_path.value
config_data["gpt_sovits"]["api_0706"]["prompt_text"] = input_gpt_sovits_api_0706_prompt_text.value
config_data["gpt_sovits"]["api_0706"]["prompt_language"] = select_gpt_sovits_api_0706_prompt_language.value
config_data["gpt_sovits"]["api_0706"]["text_language"] = select_gpt_sovits_api_0706_text_language.value
config_data["gpt_sovits"]["api_0706"]["cut_punc"] = input_gpt_sovits_api_0706_cut_punc.value

config_data["gpt_sovits"]["webtts"]["version"] = select_gpt_sovits_webtts_version.value
config_data["gpt_sovits"]["webtts"]["api_ip_port"] = input_gpt_sovits_webtts_api_ip_port.value
config_data["gpt_sovits"]["webtts"]["spk"] = input_gpt_sovits_webtts_spk.value
Expand Down Expand Up @@ -4987,7 +4993,7 @@ def vits_get_speaker_id():
with ui.row():
select_gpt_sovits_type = ui.select(
label='API类型',
options={'gradio':'gradio旧版', 'gradio_0322':'gradio_0322', 'api':'api', 'api_0322':'api_0322', 'webtts':'WebTTS'},
options={'api':'api', 'api_0322':'api_0322', 'api_0706':'api_0706', 'webtts':'WebTTS', 'gradio':'gradio旧版', 'gradio_0322':'gradio_0322'},
value=config.get("gpt_sovits", "type")
).style("width:100px;")
input_gpt_sovits_gradio_ip_port = ui.input(
Expand Down Expand Up @@ -5086,7 +5092,32 @@ def vits_get_speaker_id():
input_gpt_sovits_api_0322_fragment_interval = ui.input(label='分段间隔(秒)', value=config.get("gpt_sovits", "api_0322", "fragment_interval"), placeholder='fragment_interval').style("width:100px;")
switch_gpt_sovits_api_0322_split_bucket = ui.switch('split_bucket', value=config.get("gpt_sovits", "api_0322", "split_bucket")).style(switch_internal_css)
switch_gpt_sovits_api_0322_return_fragment = ui.switch('return_fragment', value=config.get("gpt_sovits", "api_0322", "return_fragment")).style(switch_internal_css)


with ui.card().style(card_css):
ui.label("api_0706")
with ui.row():
input_gpt_sovits_api_0706_refer_wav_path = ui.input(label='参考音频路径', value=config.get("gpt_sovits", "api_0706", "refer_wav_path"), placeholder='参考音频路径,建议填绝对路径').style("width:300px;")
input_gpt_sovits_api_0706_prompt_text = ui.input(label='参考音频的文本', value=config.get("gpt_sovits", "api_0706", "prompt_text"), placeholder='参考音频的文本').style("width:200px;")
select_gpt_sovits_api_0706_prompt_language = ui.select(
label='参考音频的语种',
options={'中文':'中文', '日文':'日文', '英文':'英文'},
value=config.get("gpt_sovits", "api_0706", "prompt_language")
).style("width:150px;")
select_gpt_sovits_api_0706_text_language = ui.select(
label='需要合成的语种',
options={
'自动识别':'自动识别',
'中文':'中文',
'日文':'日文',
'英文':'英文',
'中英混合': '中英混合',
'日英混合': '日英混合',
'多语种混合': '多语种混合',
},
value=config.get("gpt_sovits", "api_0706", "text_language")
).style("width:150px;")
input_gpt_sovits_api_0706_cut_punc = ui.input(label='文本切分', value=config.get("gpt_sovits", "api_0706", "cut_punc"), placeholder='文本切分符号设定, 符号范围,.;?!、,。?!;:…').style("width:200px;")


with ui.card().style(card_css):
ui.label("WebTTS相关配置")
Expand Down

0 comments on commit cd088d6

Please sign in to comment.