Skip to content

Commit

Permalink
fix Bark in React UI, add Max Generation Duration (#361)
Browse files Browse the repository at this point in the history
* fix Bark in React UI, add Max Generation Duration
  • Loading branch information
rsxdalv authored Aug 5, 2024
1 parent 4a22ec3 commit 62d8893
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 51 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ config.json
/data/models/xtts/
/data/models/hubert_base.pt
/data/models/rmvpe.pt
/data/models/audiocraft_plus/

# Ignore temporary files
temp/
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@ List of models: Bark, MusicGen + AudioGen, Tortoise, RVC, Vocos, Demucs, Seamles

## Changelog

Aug 5:
* Fix Bark in React UI, add Max Generation Duration.
* Change AudioCraft Plus extension models directory to ./data/models/audiocraft_plus/
* Improve model unloading for MusicGen and AudioGen. Add unload models button to MusicGen and AudioGen.

Aug 4:
* Add XTTS-RVC-UI extension, XTTS Fine-tuning demo extension.

Aug 3:
* Riffusion extension
* Add Riffusion extension, AudioCraft Mac extension, Bark Legacy extension.

Aug 2:
* Add deprecation warning to old installer.
Expand Down
22 changes: 22 additions & 0 deletions react-ui/src/components/BarkInputs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ import {
} from "./BarkRadios";
import { SeedInput } from "./SeedInput";
import { HandleChange } from "../types/HandleChange";
import { GenericSlider } from "./GenericSlider";

/**
 * Slider control for the Bark `max_gen_duration_s` generation parameter.
 *
 * Renders a `GenericSlider` labelled "Max generation duration (s)" that
 * ranges from 0.1 to 18 in steps of 0.1, reading its value from
 * `barkGenerationParams` and reporting edits through `handleChange`.
 */
const MaxGenDuration = (props: {
  barkGenerationParams: BarkGenerationParams;
  handleChange: HandleChange;
}) => {
  const { barkGenerationParams, handleChange } = props;

  return (
    <GenericSlider
      label="Max generation duration (s)"
      name="max_gen_duration_s"
      min="0.1"
      max="18"
      step="0.1"
      params={barkGenerationParams}
      handleChange={handleChange}
    />
  );
};

export const BarkInputs = ({
barkGenerationParams,
Expand Down Expand Up @@ -65,6 +83,10 @@ export const BarkInputs = ({
barkGenerationParams={barkGenerationParams}
handleChange={handleChange}
/>
<MaxGenDuration
barkGenerationParams={barkGenerationParams}
handleChange={handleChange}
/>
<TextTemperature
barkGenerationParams={barkGenerationParams}
handleChange={handleChange}
Expand Down
65 changes: 41 additions & 24 deletions react-ui/src/hooks/useLocalStorage.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import { Dispatch, SetStateAction, useCallback, useEffect, useState } from "react";
import {
Dispatch,
SetStateAction,
useCallback,
useEffect,
useState,
} from "react";

const defaultNamespace = "tts-generation-webui__";
export const defaultNamespace = "tts-generation-webui__";

const readLocalStorage = (key: string) => {
const prefixedKey = defaultNamespace + key;
Expand All @@ -22,7 +28,8 @@ export const updateLocalStorageWithFunction = (key: string, value: any) =>
export default function useLocalStorage<T>(
key: string,
initialValue: T,
namespace = defaultNamespace
namespace = defaultNamespace,
extend = false
): [T, Dispatch<SetStateAction<T>>] {
const [storedValue, setStoredValue] = useState(initialValue);
// We will use this flag to trigger the reading from localStorage
Expand All @@ -40,7 +47,11 @@ export default function useLocalStorage<T>(
}
try {
const item = window.localStorage.getItem(prefixedKey);
return item ? (JSON.parse(item) as T) : initialValue;
return item
? extend
? ({ ...initialValue, ...JSON.parse(item) } as T)
: (JSON.parse(item) as T)
: initialValue;
} catch (error) {
console.error(error);
return initialValue;
Expand All @@ -53,29 +64,35 @@ export default function useLocalStorage<T>(
setFirstLoadDone(true);
}, [initialValue, prefixedKey]);

const setLocalValue = useCallback((value: T) => {
if (!firstLoadDone) {
return;
}
const setLocalValue = useCallback(
(value: T) => {
if (!firstLoadDone) {
return;
}

try {
if (typeof window !== "undefined") {
window.localStorage.setItem(prefixedKey, JSON.stringify(value));
try {
if (typeof window !== "undefined") {
window.localStorage.setItem(prefixedKey, JSON.stringify(value));
}
} catch (error) {
console.log(error);
}
} catch (error) {
console.log(error);
}
}, [firstLoadDone, prefixedKey]);
},
[firstLoadDone, prefixedKey]
);

const setValue: Dispatch<SetStateAction<T>> = useCallback((value) => {
// Allow value to be a function so we have the same API as useState
// const valueToStore = value instanceof Function ? value(storedValue) : value;
setStoredValue(x => {
const newValue = value instanceof Function ? value(x) : value;
setLocalValue(newValue);
return newValue;
});
}, [setLocalValue]);
const setValue: Dispatch<SetStateAction<T>> = useCallback(
(value) => {
// Allow value to be a function so we have the same API as useState
// const valueToStore = value instanceof Function ? value(storedValue) : value;
setStoredValue((x) => {
const newValue = value instanceof Function ? value(x) : value;
setLocalValue(newValue);
return newValue;
});
},
[setLocalValue]
);

// watch localStorage changes
// useEffect(() => {
Expand Down
34 changes: 15 additions & 19 deletions react-ui/src/pages/api/gradio/[name].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ export default async function handler(
res.status(200).json(result);
}

const defaultBackend = process.env.GRADIO_BACKEND || process.env.GRADIO_BACKEND_AUTOMATIC || "http://127.0.0.1:7860/";
const defaultBackend =
process.env.GRADIO_BACKEND ||
process.env.GRADIO_BACKEND_AUTOMATIC ||
"http://127.0.0.1:7770/";
const getClient = () => client(defaultBackend, {});

type GradioChoices = {
Expand Down Expand Up @@ -152,10 +155,12 @@ async function bark({
old_generation_dropdown,
seed: seed_input,
history_prompt_semantic_dropdown,
max_gen_duration_s,
}) {
const result = await gradioPredict<
[
GradioFile, // audio
// GradioFile, // audio
{ value: GradioFile; label: string }, // npz
string, // image
Object, // save_button
Object, // continue_button
Expand All @@ -180,28 +185,19 @@ async function bark({
old_generation_dropdown,
seed_input,
history_prompt_semantic_dropdown,
max_gen_duration_s,
]);

const [
audio,
image,
save_button,
continue_button,
buttons_row,
npz,
seed,
json_text,
history_bundle_name_data,
] = result?.data;
const [audio_update, npz, json_text, history_bundle_name_data] = result?.data;

const audio = audio_update.value;
const fixedAudio = {
...audio,
data: `http://127.0.0.1:7770/file=${audio.name}`,
};
return {
audio,
image,
save_button,
continue_button,
buttons_row,
audio: fixedAudio,
npz,
seed,
json_text,
history_bundle_name_data,
};
Expand Down
4 changes: 3 additions & 1 deletion react-ui/src/tabs/BarkGenerationParams.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const inputs = {
seed: "123",
history_prompt_semantic_dropdown:
"voices\\2023-06-18_21-51-07__bark__continued_generation.npz",
max_gen_duration_s: 15,
};

export type BarkGenerationParams = {
Expand All @@ -36,13 +37,14 @@ export type BarkGenerationParams = {
old_generation_dropdown: string; // string
seed: string; // string in 'parameter_40' Textbox component
history_prompt_semantic_dropdown: string; // string
max_gen_duration_s: number; // number in 'Max generation duration (s)' Number component
};

export const initialState: BarkGenerationParams = {
...inputs,
};

export const barkGenerationId = "bark_generation-tab";
export const barkGenerationId = "bark_generation-tab.v2";

export const sendToBarkAsVoice = (old_generation_dropdown?: string) => {
if (!old_generation_dropdown) return;
Expand Down
5 changes: 0 additions & 5 deletions react-ui/src/tabs/BarkResult.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@ import { GradioFile } from "../types/GradioFile";

export type BarkResult = {
audio: GradioFile;
image: string;
save_button: Object;
continue_button: Object;
buttons_row: Object;
npz: string;
seed: null;
json_text: {
_version: string;
_hash_version: string;
Expand Down
3 changes: 2 additions & 1 deletion src/extensions_loader/interface_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def _handle_package(package_name, title_name, requirements):
pip_install_wrapper(requirements, title_name),
outputs=[gr.HTML()],
)
main_tab()
with gr.Tabs():
main_tab()
except Exception as e:
generic_error_tab_advanced(
e, name=title_name + " Extension", requirements=requirements
Expand Down
19 changes: 19 additions & 0 deletions src/musicgen/musicgen_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ def log_generation_musicgen(
print(key, ":", value)


def unload_models():
    """Drop the cached model and release the memory it was holding.

    Clears the module-level ``MODEL`` reference, runs a garbage-collection
    pass and, if CUDA is available, empties the CUDA cache.

    Returns:
        The string ``"Unloaded"``.
    """
    global MODEL
    import gc

    # Remove the only module-level reference so the collector can reclaim it.
    MODEL = None
    gc.collect()

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()

    return "Unloaded"


def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarray]]):
model = params["model"]
text = params["text"]
Expand All @@ -147,6 +158,7 @@ def generate(params: MusicGenGeneration, melody_in: Optional[Tuple[int, np.ndarr

global MODEL
if MODEL is None or MODEL.name != model:
unload_models()
MODEL = load_model(model)

MODEL.set_generation_params(
Expand Down Expand Up @@ -308,6 +320,13 @@ def generation_tab_musicgen():
)
seed, set_old_seed_button, _ = setup_seed_ui_musicgen()

unload_models_button = gr.Button("Unload models")
unload_models_button.click(
fn=unload_models,
outputs=[unload_models_button],
api_name="musicgen_audiogen_unload_models",
)

with gr.Column():
output = gr.Audio(
label="Generated Music",
Expand Down

0 comments on commit 62d8893

Please sign in to comment.