From d7d6839f0565f87d6426992157d552d2cd2a6259 Mon Sep 17 00:00:00 2001 From: LastDemon99 Date: Tue, 11 Jun 2024 04:11:03 -0400 Subject: [PATCH 1/6] Update .gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3e6aee6..5b5bbc3 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,6 @@ cython_debug/ *.ckpt *.wav -wandb/* \ No newline at end of file +wandb/* +models/* +outputs/* From d3bad42fff5150e7268cada067bc4d81621becb4 Mon Sep 17 00:00:00 2001 From: LastDemon99 Date: Wed, 12 Jun 2024 00:45:53 -0400 Subject: [PATCH 2/6] Add txt2audio_ui.bat - Added to start the txt2audio interface without typing commands. - Checks for the previously selected model in 'config/txt2audio.json'; otherwise searches the models folder for one. --- config/txt2audio.json | 3 +++ run_gradio.py | 3 ++- txt2audio_ui.bat | 47 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 config/txt2audio.json create mode 100644 txt2audio_ui.bat diff --git a/config/txt2audio.json b/config/txt2audio.json new file mode 100644 index 0000000..6c447ed --- /dev/null +++ b/config/txt2audio.json @@ -0,0 +1,3 @@ +{ + "model_selected": "" +} diff --git a/run_gradio.py b/run_gradio.py index ae3ba95..18e61dd 100644 --- a/run_gradio.py +++ b/run_gradio.py @@ -15,7 +15,7 @@ def main(args): model_half=args.model_half ) interface.queue() - interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None) + interface.launch(share=True, auth=(args.username, args.password) if args.username is not None else None, inbrowser=args.inbrowser if args.inbrowser is not None else False) if __name__ == "__main__": import argparse @@ -27,5 +27,6 @@ def main(args): parser.add_argument('--username', type=str, help='Gradio username', required=False) parser.add_argument('--password', type=str, help='Gradio password', required=False) parser.add_argument('--model-half', 
action='store_true', help='Whether to use half precision', required=False) + parser.add_argument('--inbrowser', action='store_true', help='Open browser on launch', required=False) args = parser.parse_args() main(args) \ No newline at end of file diff --git a/txt2audio_ui.bat b/txt2audio_ui.bat new file mode 100644 index 0000000..6ff906f --- /dev/null +++ b/txt2audio_ui.bat @@ -0,0 +1,47 @@ +@echo off +setlocal + +set config_path='config/txt2audio.json' + +for /f "delims=" %%i in ('python -c "import json; f=open(%config_path%); data=json.load(f); f.close(); print(data['model_selected'])"') do set model_selected=%%i + +set models_path=.\models\ + +if defined model_selected ( + for %%i in (.ckpt .safetensors .pth) do ( + if exist %models_path%%model_selected%%%i ( + set model_path=%models_path%%model_selected%%%i + set model_name=%model_selected% + goto :model_found + ) + ) +) +set config_model_found=. +echo No model found in config file +echo Searching in models folder + +for /R %models_path% %%f in (*.ckpt *.safetensors *.pth) do ( + set model_path=%%~dpnxf + set model_name=%%~nf + goto :model_found +) +echo No model found +pause +exit /b + +:model_found +echo Found model: %model_name% +set model_config_path=%models_path%%model_name%.json +if not exist %model_config_path% ( + echo Model config not found. + pause + exit /b +) + +if defined config_model_found ( + python -c "import json; f=open(%config_path%); data=json.load(f); data['model_selected'] = '%model_name%'; f.close(); f=open(%config_path%, 'w'); json.dump(data, f, indent=4); f.write('\n'); f.close()" +) + +call .\venv\Scripts\activate.bat +call python run_gradio.py --ckpt-path %model_path% --model-config %model_config_path% --inbrowser +pause From 02590226eb9f2c2f4d023c097b915b8a30b72c2d Mon Sep 17 00:00:00 2001 From: LastDemon99 Date: Wed, 12 Jun 2024 05:43:51 -0400 Subject: [PATCH 3/6] Update txt2audio_ui.bat - If the models folder does not exist, it will now be created. 
--- txt2audio_ui.bat | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/txt2audio_ui.bat b/txt2audio_ui.bat index 6ff906f..d927c6a 100644 --- a/txt2audio_ui.bat +++ b/txt2audio_ui.bat @@ -7,6 +7,13 @@ for /f "delims=" %%i in ('python -c "import json; f=open(%config_path%); data=js set models_path=.\models\ +if not exist %models_path% ( + mkdir %models_path% + echo No model found + pause + exit /b +) + if defined model_selected ( for %%i in (.ckpt .safetensors .pth) do ( if exist %models_path%%model_selected%%%i ( From bdc9f6828d4a2438802cccf9c1c3b7e75baaaeb2 Mon Sep 17 00:00:00 2001 From: LastDemon99 Date: Wed, 12 Jun 2024 19:12:19 -0400 Subject: [PATCH 4/6] Add Save & Load features to txt2audio ui - Sound outputs are now saved to disk automatically, making them easy to find and access. - Three buttons are added: one to clear the prompt, one to load the generation data stored in a previously generated audio file, and one to open the folder where the outputs are saved. --- setup.py | 3 +- stable_audio_tools/data/txt2audio_utils.py | 86 ++++++++++++++++++++++ stable_audio_tools/interface/gradio.py | 42 +++++++++-- 3 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 stable_audio_tools/data/txt2audio_utils.py diff --git a/setup.py b/setup.py index 7e7470d..0232b88 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ 'vector-quantize-pytorch==1.9.14', 'wandb==0.15.4', 'webdataset==0.2.48', - 'x-transformers<1.27.0' + 'x-transformers<1.27.0', + 'pytaglib==3.0.0' ], ) \ No newline at end of file diff --git a/stable_audio_tools/data/txt2audio_utils.py b/stable_audio_tools/data/txt2audio_utils.py new file mode 100644 index 0000000..d17c2ca --- /dev/null +++ b/stable_audio_tools/data/txt2audio_utils.py @@ -0,0 +1,86 @@ +import taglib +import os +from datetime import datetime +import platform +import subprocess + +def open_outputs_path(): + outputs = f"outputs/{datetime.now().strftime('%Y-%m-%d')}" + if not os.path.isdir(outputs): + return + outputs = 
os.path.abspath(outputs) + if platform.system() == "Windows": + os.startfile(outputs) + elif platform.system() == "Darwin": + subprocess.Popen(["open", outputs]) + elif "microsoft-standard-WSL2" in platform.uname().release: + subprocess.Popen(["wsl-open", outputs]) + else: + subprocess.Popen(["xdg-open", outputs]) + +def create_output_path(suffix): + outputs = f"outputs/{datetime.now().strftime('%Y-%m-%d')}" + count = 0 + + if os.path.isdir(outputs): + counts = [os.path.splitext(file)[0].split('-')[0] for file in os.listdir(outputs) if file.endswith(".wav")] + count = max([int(i) for i in counts if i.isnumeric()]) + 1 + else: + os.makedirs(outputs) + + return f"{outputs}/{'{:05d}'.format(count)}-{suffix}.wav" + +def get_generation_data(file): + with taglib.File(file) as sound: + if len(sound.tags) != 1: + return None + + data = sound.tags["TITLE"] + + if len(data) != 12: + return None + if data[0] == "None": + data[0] = "" + if data[1] == "None": + data[1] = "" + if data[5] == "None": + data[5] = 0 + + for i in range(2, 8): + data[i] = int(data[i]) + + for i in range(9, 12): + data[i] = float(data[i]) + + data[4] = float(data[4]) + + return data + +def save_generation_data(sound_path, prompt, negative_prompt, seconds_start, seconds_total, steps, preview_every, cfg_scale, seed, sampler_type, sigma_min, sigma_max, cfg_rescale): + if prompt == "": + prompt = "None" + if negative_prompt == "": + negative_prompt = "None" + + with taglib.File(sound_path, save_on_exit=True) as sound: + sound.tags["TITLE"] = [ + prompt, + negative_prompt, + str(seconds_start), + str(seconds_total), + str(steps), + str(preview_every), + str(cfg_scale), + str(seed), + str(sampler_type), + str(sigma_min), + str(sigma_max), + str(cfg_rescale)] + +def txt2audio_css(): + return """ + #prompt_options { + flex-wrap: nowrap; + height: 40px; + } + """ diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py index b46c8d4..07309b9 100644 --- 
a/stable_audio_tools/interface/gradio.py +++ b/stable_audio_tools/interface/gradio.py @@ -163,6 +163,7 @@ def progress_callback(callback_info): else: mask_args = None + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) # Do the audio generation audio = generate_diffusion_cond( model, @@ -188,12 +189,16 @@ def progress_callback(callback_info): # Convert to WAV file audio = rearrange(audio, "b d n -> d (b n)") audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() - torchaudio.save("output.wav", audio, sample_rate) + + from stable_audio_tools.data.txt2audio_utils import create_output_path, save_generation_data + output_path = create_output_path(seed) + torchaudio.save(output_path, audio, sample_rate) + save_generation_data(output_path, prompt, negative_prompt, seconds_start, seconds_total, steps, preview_every, cfg_scale, seed, sampler_type, sigma_min, sigma_max, cfg_rescale) # Let's look at a nice spectrogram too audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) - return ("output.wav", [audio_spectrogram, *preview_images]) + return (output_path, [audio_spectrogram, *preview_images]) def generate_uncond( steps=250, @@ -380,6 +385,10 @@ def create_sampling_ui(model_config, inpainting=False): with gr.Column(scale=6): prompt = gr.Textbox(show_label=False, placeholder="Prompt") negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt") + with gr.Row(elem_id="prompt_options"): + clear_prompt = gr.Button('\U0001f5d1\ufe0f') + paste_generation_data = gr.Button('\u2199\ufe0f') + insert_generation_data = gr.File(label="Insert generation data from output.wav", file_types=[".wav"], scale=0) generate_button = gr.Button("Generate", variant='primary', scale=1) model_conditioning_config = model_config["model"].get("conditioning", None) @@ -492,11 +501,31 @@ def create_sampling_ui(model_config, inpainting=False): with gr.Column(): audio_output = 
gr.Audio(label="Output audio", interactive=False) - audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) - send_to_init_button = gr.Button("Send to init audio", scale=1) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + with gr.Row(): + open_outputs_folder = gr.Button("\U0001f4c1", scale=1) + send_to_init_button = gr.Button("Send to init audio", scale=1) + from stable_audio_tools.data.txt2audio_utils import open_outputs_path + open_outputs_folder.click(fn=open_outputs_path) send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) - generate_button.click(fn=generate_cond, + from stable_audio_tools.data.txt2audio_utils import get_generation_data + paste_generation_data.click(fn=get_generation_data, inputs=[insert_generation_data], outputs=[prompt, + negative_prompt, + seconds_start_slider, + seconds_total_slider, + steps_slider, + preview_every_slider, + cfg_scale_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider]) + + clear_prompt.click(fn=lambda: ("", ""), outputs=[prompt, negative_prompt]) + + generate_button.click(fn=generate_cond, inputs=inputs, outputs=[ audio_output, @@ -506,7 +535,8 @@ def create_sampling_ui(model_config, inpainting=False): def create_txt2audio_ui(model_config): - with gr.Blocks() as ui: + from stable_audio_tools.data.txt2audio_utils import txt2audio_css + with gr.Blocks(css=txt2audio_css()) as ui: with gr.Tab("Generation"): create_sampling_ui(model_config) with gr.Tab("Inpainting"): From 0e410931ab47985b2ac06c2f03a84390bba11f87 Mon Sep 17 00:00:00 2001 From: LastDemon99 Date: Thu, 13 Jun 2024 01:16:12 -0400 Subject: [PATCH 5/6] Add model change option to txt2audio ui - A drop-down list of available models is added, allowing the selected model to be changed. - A button is added to refresh the model list in case new models are added to the models folder. 
--- stable_audio_tools/data/txt2audio_utils.py | 44 ++++++++++++++++++++ stable_audio_tools/interface/gradio.py | 48 ++++++++++++++++++---- 2 files changed, 83 insertions(+), 9 deletions(-) diff --git a/stable_audio_tools/data/txt2audio_utils.py b/stable_audio_tools/data/txt2audio_utils.py index d17c2ca..a6a8247 100644 --- a/stable_audio_tools/data/txt2audio_utils.py +++ b/stable_audio_tools/data/txt2audio_utils.py @@ -3,6 +3,32 @@ from datetime import datetime import platform import subprocess +import json + +def set_selected_model(model_name): + if model_name in [data["name"] for data in get_models_data()]: + config = get_config() + config["model_selected"] = model_name + with open("config/txt2audio.json", "w") as file: + json.dump(config, file, indent=4) + file.write('\n') + +def get_config(): + with open("config/txt2audio.json") as file: + return json.load(file) + +def get_models_name(): + return [model["name"] for model in get_models_data()] + +def get_models_data(): + models = [] + file_types = ['.ckpt', '.safetensors', '.pth'] + for file in os.listdir("models/"): + _file = os.path.splitext(file) + config_path = f"models/{_file[0]}.json" + if _file[1] in file_types and os.path.isfile(config_path): + models.append({"name": _file[0], "path": f"models/{file}", "config_path": config_path}) + return models def open_outputs_path(): outputs = f"outputs/{datetime.now().strftime('%Y-%m-%d')}" @@ -79,8 +105,26 @@ def save_generation_data(sound_path, prompt, negative_prompt, seconds_start, sec def txt2audio_css(): return """ + div.svelte-sa48pu>*, div.svelte-sa48pu>.form>* { + flex: 1 1 0%; + flex-wrap: wrap; + min-width: min(40px, 100%); + } + + #refresh_btn { + padding: 0px; + } + + #selected_model_items div.svelte-1sk0pyu div.wrap.svelte-1sk0pyu div.wrap-inner.svelte-1sk0pyu div.secondary-wrap.svelte-1sk0pyu input.border-none.svelte-1sk0pyu { + margin: 0px; + } + #prompt_options { flex-wrap: nowrap; height: 40px; } + + #selected_model_container { + gap: 3px; + } """ 
diff --git a/stable_audio_tools/interface/gradio.py b/stable_audio_tools/interface/gradio.py index 07309b9..9bc3aa6 100644 --- a/stable_audio_tools/interface/gradio.py +++ b/stable_audio_tools/interface/gradio.py @@ -21,6 +21,7 @@ model = None sample_rate = 32000 sample_size = 1920000 +model_is_half = None def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): global model, sample_rate, sample_size @@ -55,6 +56,32 @@ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pr return model, model_config +def unload_model(): + global model + del model + model = None + torch.cuda.empty_cache() + gc.collect() + +def txt2audio_change_model(model_name): + from stable_audio_tools.data.txt2audio_utils import get_models_data, set_selected_model + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for model_data in get_models_data(): + if model_data["name"] == model_name: + unload_model() + set_selected_model(model_name) + model_config = get_model_config_from_path(model_data["config_path"]) + load_model(model_config, model_data["path"], model_half=model_is_half, device=device) + return model_name + +def get_model_config_from_path(model_config_path): + if model_config_path is not None: + # Load config from json file + with open(model_config_path) as f: + return json.load(f) + else: + return None + def generate_cond( prompt, negative_prompt=None, @@ -533,10 +560,17 @@ def create_sampling_ui(model_config, inpainting=False): ], api_name="generate") - def create_txt2audio_ui(model_config): - from stable_audio_tools.data.txt2audio_utils import txt2audio_css + from stable_audio_tools.data.txt2audio_utils import txt2audio_css, get_models_name, get_config with gr.Blocks(css=txt2audio_css()) as ui: + with gr.Column(elem_id="selected_model_container"): + gr.HTML('', visible=True) + with gr.Row(): + selected_model_dropdown = gr.Dropdown(get_models_name(), 
container=False, value=get_config()["model_selected"], interactive=True, scale=0, min_width=265, elem_id="selected_model_items") + selected_model_dropdown.change(fn=txt2audio_change_model, inputs=selected_model_dropdown, outputs=selected_model_dropdown) + refresh_models_button = gr.Button("\U0001f504", scale=0, elem_id="refresh_btn") + refresh_models_button.click(fn=lambda: gr.Dropdown(choices=get_models_name()), outputs=selected_model_dropdown) + gr.HTML('
', visible=True) with gr.Tab("Generation"): create_sampling_ui(model_config) with gr.Tab("Inpainting"): @@ -685,16 +719,12 @@ def create_lm_ui(model_config): return ui def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): + global model_is_half + model_is_half = model_half assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" - if model_config_path is not None: - # Load config from json file - with open(model_config_path) as f: - model_config = json.load(f) - else: - model_config = None - + model_config = get_model_config_from_path(model_config_path) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) From 888282eb316aaa35237dfc72cc8da37b79ec6694 Mon Sep 17 00:00:00 2001 From: LastDemon99 Date: Thu, 13 Jun 2024 01:46:55 -0400 Subject: [PATCH 6/6] Update open outputs folder to txt2audio ui - Previously, clicking the button opened the folder only if audio had been generated today. Now, if there are no generations from today, it opens the folder containing all generations; if no generations exist at all, it does not open anything. 
--- stable_audio_tools/data/txt2audio_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/stable_audio_tools/data/txt2audio_utils.py b/stable_audio_tools/data/txt2audio_utils.py index a6a8247..e6b6027 100644 --- a/stable_audio_tools/data/txt2audio_utils.py +++ b/stable_audio_tools/data/txt2audio_utils.py @@ -31,9 +31,15 @@ def get_models_data(): return models def open_outputs_path(): - outputs = f"outputs/{datetime.now().strftime('%Y-%m-%d')}" + outputs_dir = "outputs/" + outputs = outputs_dir + datetime.now().strftime('%Y-%m-%d') + if not os.path.isdir(outputs): - return + if not os.path.isdir(outputs_dir): + return + else: + outputs = outputs_dir + outputs = os.path.abspath(outputs) if platform.system() == "Windows": os.startfile(outputs)