Improved usability for txt2audio ui #96

Open · wants to merge 6 commits into base: main
4 changes: 3 additions & 1 deletion .gitignore
@@ -161,4 +161,6 @@ cython_debug/

*.ckpt
*.wav
wandb/*
wandb/*
models/*
outputs/*
3 changes: 3 additions & 0 deletions config/txt2audio.json
@@ -0,0 +1,3 @@
{
    "model_selected": ""
}
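
Reviewer note, not part of the diff: this new file persists the model dropdown's last selection. The helpers added in stable_audio_tools/data/txt2audio_utils.py below read and rewrite it; a minimal sketch, assuming a checkpoint named "my_model" sits in models/ with a matching my_model.json:

    from stable_audio_tools.data.txt2audio_utils import get_config, set_selected_model

    set_selected_model("my_model")         # rewrites config/txt2audio.json (only for a known model)
    print(get_config()["model_selected"])  # -> "my_model"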
3 changes: 2 additions & 1 deletion run_gradio.py
@@ -15,7 +15,7 @@ def main(args):
        model_half=args.model_half
    )
    interface.queue()
    interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None)
    interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None, inbrowser=args.inbrowser)

if __name__ == "__main__":
    import argparse
@@ -27,5 +27,6 @@ def main(args):
    parser.add_argument('--username', type=str, help='Gradio username', required=False)
    parser.add_argument('--password', type=str, help='Gradio password', required=False)
    parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False)
    parser.add_argument('--inbrowser', action='store_true', help='Open browser on launch', required=False)
    args = parser.parse_args()
    main(args)
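
Usage sketch, not part of the diff, with illustrative paths: since --inbrowser is an argparse store_true flag, passing it makes the UI open a browser tab on launch:

    python run_gradio.py --model-config models/my_model.json --ckpt-path models/my_model.ckpt --inbrowser

Because store_true always yields a bool, args.inbrowser can be passed straight to interface.launch().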
3 changes: 2 additions & 1 deletion setup.py
@@ -39,6 +39,7 @@
        'vector-quantize-pytorch==1.9.14',
        'wandb==0.15.4',
        'webdataset==0.2.48',
        'x-transformers<1.27.0'
        'x-transformers<1.27.0',
        'pytaglib==3.0.0'
    ],
)
136 changes: 136 additions & 0 deletions stable_audio_tools/data/txt2audio_utils.py
@@ -0,0 +1,136 @@
import taglib
import os
from datetime import datetime
import platform
import subprocess
import json

def set_selected_model(model_name):
    # Persist the selection only if it matches a model discovered in models/.
    if model_name in [data["name"] for data in get_models_data()]:
        config = get_config()
        config["model_selected"] = model_name
        with open("config/txt2audio.json", "w") as file:
            json.dump(config, file, indent=4)
            file.write('\n')

def get_config():
    with open("config/txt2audio.json") as file:
        return json.load(file)

def get_models_name():
    return [model["name"] for model in get_models_data()]

def get_models_data():
    # A checkpoint is only listed if a matching <name>.json config sits next to it.
    models = []
    file_types = ['.ckpt', '.safetensors', '.pth']
    for file in os.listdir("models/"):
        _file = os.path.splitext(file)
        config_path = f"models/{_file[0]}.json"
        if _file[1] in file_types and os.path.isfile(config_path):
            models.append({"name": _file[0], "path": f"models/{file}", "config_path": config_path})
    return models

def open_outputs_path():
    outputs_dir = "outputs/"
    outputs = outputs_dir + datetime.now().strftime('%Y-%m-%d')

    # Fall back to the outputs root if today's folder does not exist yet.
    if not os.path.isdir(outputs):
        if not os.path.isdir(outputs_dir):
            return
        else:
            outputs = outputs_dir

    outputs = os.path.abspath(outputs)
    if platform.system() == "Windows":
        os.startfile(outputs)
    elif platform.system() == "Darwin":
        subprocess.Popen(["open", outputs])
    elif "microsoft-standard-WSL2" in platform.uname().release:
        subprocess.Popen(["wsl-open", outputs])
    else:
        subprocess.Popen(["xdg-open", outputs])

def create_output_path(suffix):
    # Output files are numbered per day: outputs/<date>/<count>-<suffix>.wav
    outputs = f"outputs/{datetime.now().strftime('%Y-%m-%d')}"
    count = 0

    if os.path.isdir(outputs):
        counts = [os.path.splitext(file)[0].split('-')[0] for file in os.listdir(outputs) if file.endswith(".wav")]
        count = max([int(i) for i in counts if i.isnumeric()], default=-1) + 1
    else:
        os.makedirs(outputs)

    return f"{outputs}/{'{:05d}'.format(count)}-{suffix}.wav"

def get_generation_data(file):
    with taglib.File(file) as sound:
        # Generation settings are stored as a 12-element string list in the TITLE tag.
        if len(sound.tags) != 1 or "TITLE" not in sound.tags:
            return None

        data = sound.tags["TITLE"]

        if len(data) != 12:
            return None
        if data[0] == "None":
            data[0] = ""
        if data[1] == "None":
            data[1] = ""
        if data[5] == "None":
            data[5] = 0

        # seconds_start, seconds_total, steps, preview_every, seed -> int
        for i in (2, 3, 4, 5, 7):
            data[i] = int(float(data[i]))

        # cfg_scale, sigma_min, sigma_max, cfg_rescale -> float
        for i in (6, 9, 10, 11):
            data[i] = float(data[i])

        return data

def save_generation_data(sound_path, prompt, negative_prompt, seconds_start, seconds_total, steps, preview_every, cfg_scale, seed, sampler_type, sigma_min, sigma_max, cfg_rescale):
    # Empty prompts are stored as the literal placeholder "None" so the tag keeps 12 entries.
    if prompt == "":
        prompt = "None"
    if negative_prompt == "":
        negative_prompt = "None"

    with taglib.File(sound_path, save_on_exit=True) as sound:
        sound.tags["TITLE"] = [
            prompt,
            negative_prompt,
            str(seconds_start),
            str(seconds_total),
            str(steps),
            str(preview_every),
            str(cfg_scale),
            str(seed),
            str(sampler_type),
            str(sigma_min),
            str(sigma_max),
            str(cfg_rescale)]

def txt2audio_css():
    return """
    div.svelte-sa48pu>*, div.svelte-sa48pu>.form>* {
        flex: 1 1 0%;
        flex-wrap: wrap;
        min-width: min(40px, 100%);
    }

    #refresh_btn {
        padding: 0px;
    }

    #selected_model_items div.svelte-1sk0pyu div.wrap.svelte-1sk0pyu div.wrap-inner.svelte-1sk0pyu div.secondary-wrap.svelte-1sk0pyu input.border-none.svelte-1sk0pyu {
        margin: 0px;
    }

    #prompt_options {
        flex-wrap: nowrap;
        height: 40px;
    }

    #selected_model_container {
        gap: 3px;
    }
    """
88 changes: 74 additions & 14 deletions stable_audio_tools/interface/gradio.py
@@ -21,6 +21,7 @@
model = None
sample_rate = 32000
sample_size = 1920000
model_is_half = None

def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
    global model, sample_rate, sample_size
@@ -55,6 +56,32 @@ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pr

    return model, model_config

def unload_model():
    # Drop the current model and reclaim GPU memory before loading another one.
    global model
    del model
    model = None
    torch.cuda.empty_cache()
    gc.collect()

def txt2audio_change_model(model_name):
    from stable_audio_tools.data.txt2audio_utils import get_models_data, set_selected_model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for model_data in get_models_data():
        if model_data["name"] == model_name:
            unload_model()
            set_selected_model(model_name)
            model_config = get_model_config_from_path(model_data["config_path"])
            load_model(model_config, model_data["path"], model_half=model_is_half, device=device)
    return model_name

def get_model_config_from_path(model_config_path):
    if model_config_path is not None:
        # Load config from json file
        with open(model_config_path) as f:
            return json.load(f)
    else:
        return None

def generate_cond(
        prompt,
        negative_prompt=None,
@@ -163,6 +190,7 @@ def progress_callback(callback_info):
    else:
        mask_args = None

    seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
    # Do the audio generation
    audio = generate_diffusion_cond(
        model,
@@ -188,12 +216,16 @@ def progress_callback(callback_info):
    # Convert to WAV file
    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    torchaudio.save("output.wav", audio, sample_rate)

    from stable_audio_tools.data.txt2audio_utils import create_output_path, save_generation_data
    output_path = create_output_path(seed)
    torchaudio.save(output_path, audio, sample_rate)
    save_generation_data(output_path, prompt, negative_prompt, seconds_start, seconds_total, steps, preview_every, cfg_scale, seed, sampler_type, sigma_min, sigma_max, cfg_rescale)

    # Let's look at a nice spectrogram too
    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    return ("output.wav", [audio_spectrogram, *preview_images])
    return (output_path, [audio_spectrogram, *preview_images])

def generate_uncond(
        steps=250,
@@ -380,6 +412,10 @@ def create_sampling_ui(model_config, inpainting=False):
        with gr.Column(scale=6):
            prompt = gr.Textbox(show_label=False, placeholder="Prompt")
            negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
            with gr.Row(elem_id="prompt_options"):
                clear_prompt = gr.Button('\U0001f5d1\ufe0f')
                paste_generation_data = gr.Button('\u2199\ufe0f')
                insert_generation_data = gr.File(label="Insert generation data from output.wav", file_types=[".wav"], scale=0)
        generate_button = gr.Button("Generate", variant='primary', scale=1)

    model_conditioning_config = model_config["model"].get("conditioning", None)
@@ -492,21 +528,49 @@ def create_sampling_ui(model_config, inpainting=False):

        with gr.Column():
            audio_output = gr.Audio(label="Output audio", interactive=False)
            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
            send_to_init_button = gr.Button("Send to init audio", scale=1)
            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
            with gr.Row():
                open_outputs_folder = gr.Button("\U0001f4c1", scale=1)
                send_to_init_button = gr.Button("Send to init audio", scale=1)
            from stable_audio_tools.data.txt2audio_utils import open_outputs_path
            open_outputs_folder.click(fn=open_outputs_path)
            send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])

    generate_button.click(fn=generate_cond,
    from stable_audio_tools.data.txt2audio_utils import get_generation_data
    paste_generation_data.click(fn=get_generation_data, inputs=[insert_generation_data], outputs=[prompt,
        negative_prompt,
        seconds_start_slider,
        seconds_total_slider,
        steps_slider,
        preview_every_slider,
        cfg_scale_slider,
        seed_textbox,
        sampler_type_dropdown,
        sigma_min_slider,
        sigma_max_slider,
        cfg_rescale_slider])

    clear_prompt.click(fn=lambda: ("", ""), outputs=[prompt, negative_prompt])

    generate_button.click(fn=generate_cond,
        inputs=inputs,
        outputs=[
            audio_output,
            audio_spectrogram_output
        ],
        api_name="generate")


def create_txt2audio_ui(model_config):
    with gr.Blocks() as ui:
    from stable_audio_tools.data.txt2audio_utils import txt2audio_css, get_models_name, get_config
    with gr.Blocks(css=txt2audio_css()) as ui:
        with gr.Column(elem_id="selected_model_container"):
            gr.HTML('<label>Selected Model</label>', visible=True)
            with gr.Row():
                selected_model_dropdown = gr.Dropdown(get_models_name(), container=False, value=get_config()["model_selected"], interactive=True, scale=0, min_width=265, elem_id="selected_model_items")
                selected_model_dropdown.change(fn=txt2audio_change_model, inputs=selected_model_dropdown, outputs=selected_model_dropdown)
                refresh_models_button = gr.Button("\U0001f504", scale=0, elem_id="refresh_btn")
                refresh_models_button.click(fn=lambda: gr.Dropdown(choices=get_models_name()), outputs=selected_model_dropdown)
            gr.HTML('<div style="/* flex-grow: 2; */">', visible=True)
        with gr.Tab("Generation"):
            create_sampling_ui(model_config)
        with gr.Tab("Inpainting"):
@@ -655,16 +719,12 @@ def create_lm_ui(model_config):
    return ui

def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
    global model_is_half
    model_is_half = model_half

    assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"

    if model_config_path is not None:
        # Load config from json file
        with open(model_config_path) as f:
            model_config = json.load(f)
    else:
        model_config = None

    model_config = get_model_config_from_path(model_config_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)

54 changes: 54 additions & 0 deletions txt2audio_ui.bat
@@ -0,0 +1,54 @@
@echo off
setlocal

set config_path='config/txt2audio.json'

for /f "delims=" %%i in ('python -c "import json; f=open(%config_path%); data=json.load(f); f.close(); print(data['model_selected'])"') do set model_selected=%%i

set models_path=.\models\

if not exist %models_path% (
    mkdir %models_path%
    echo No model found
    pause
    exit /b
)

if defined model_selected (
    for %%i in (.ckpt .safetensors .pth) do (
        if exist %models_path%%model_selected%%%i (
            set model_path=%models_path%%model_selected%%%i
            set model_name=%model_selected%
            goto :model_found
        )
    )
)
set config_model_found=.
echo No model found in config file
echo Searching in models folder

for /R %models_path% %%f in (*.ckpt *.safetensors *.pth) do (
    set model_path=%%~dpnxf
    set model_name=%%~nf
    goto :model_found
)
echo No model found
pause
exit /b

:model_found
echo Found model: %model_name%
set model_config_path=%models_path%%model_name%.json
if not exist %model_config_path% (
    echo Model config not found.
    pause
    exit /b
)

if defined config_model_found (
    python -c "import json; f=open(%config_path%); data=json.load(f); data['model_selected'] = '%model_name%'; f.close(); f=open(%config_path%, 'w'); json.dump(data, f, indent=4); f.write('\n'); f.close()"
)

call .\venv\Scripts\activate.bat
call python run_gradio.py --ckpt-path %model_path% --model-config %model_config_path% --inbrowser
pause
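
Reviewer note, not part of the diff: the for /f line shells out to Python to read the selected model from the config. Expanded, that one-liner is:

    import json

    f = open('config/txt2audio.json')
    data = json.load(f)
    f.close()
    print(data['model_selected'])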