Stable Diffusion XL Turbo (#2959)
attila-dusnoki-htec authored and Ted Themistokleous committed Apr 25, 2024
1 parent fad93df commit e5432cb
Showing 3 changed files with 281 additions and 140 deletions.
40 changes: 35 additions & 5 deletions examples/diffusion/python_stable_diffusion_xl/README.md
@@ -28,11 +28,18 @@ export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH

Get models with huggingface-cli

### Base version

```bash
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 text_encoder/model.onnx text_encoder_2/model.onnx text_encoder_2/model.onnx_data unet/model.onnx unet/model.onnx_data vae_decoder/model.onnx --local-dir models/sdxl-1.0-base/ --local-dir-use-symlinks False
```

### Opt version

```bash
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 vae_decoder/model.onnx --local-dir models/sdxl-1.0-base/ --local-dir-use-symlinks False
huggingface-cli download stabilityai/stable-diffusion-xl-1.0-tensorrt sdxl-1.0-base/clip.opt/model.onnx sdxl-1.0-base/clip2.opt/model.onnx sdxl-1.0-base/unetxl.opt/model.onnx sdxl-1.0-base/unetxl.opt/435d4c0a-2d32-11ee-8476-0242c0a80101 --local-dir models/ --local-dir-use-symlinks False
```
*Note: `models/sdxl-1.0-base` will be used in the scripts.*

Convert CLIP models to expose "hidden_state" as output.

@@ -44,10 +51,20 @@ python clip_modifier.py -i models/sdxl-1.0-base/clip.opt/model.onnx -o models/sd
python clip_modifier.py -i models/sdxl-1.0-base/clip2.opt/model.onnx -o models/sdxl-1.0-base/clip2.opt.mod/model.onnx
```

### Turbo version

```bash
huggingface-cli download stabilityai/sdxl-turbo text_encoder/model.onnx text_encoder_2/model.onnx text_encoder_2/model.onnx_data unet/model.onnx unet/model.onnx_data vae_decoder/model.onnx --local-dir models/sdxl-turbo/ --local-dir-use-symlinks False
```

### Running txt2img

Run the text-to-image script with the following example prompt and seed:

Set `pipeline-type` based on the version of models you downloaded: `sdxl` for base, `sdxl-opt` for opt, `sdxl-turbo` for turbo

```bash
python txt2img.py --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --seed 42 --output jungle_astro.jpg
python txt2img.py --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --seed 42 --output jungle_astro.jpg --pipeline-type <model-version>
```
*Note: The first run will compile the models and cache them to make subsequent runs faster.*

@@ -61,14 +78,25 @@ Note: requires `Console application` to work

Get models with huggingface-cli

Note: Only the opt version provides an onnx model, but it can be used with all 3 versions (`sdxl`, `sdxl-opt`, `sdxl-turbo`)

```bash
huggingface-cli download stabilityai/stable-diffusion-xl-1.0-tensorrt sdxl-1.0-refiner/unetxl.opt/model.onnx sdxl-1.0-refiner/unetxl.opt/6ed855ee-2d70-11ee-af8e-0242c0a80101 sdxl-1.0-refiner/unetxl.opt/6e186582-2d74-11ee-8aa7-0242c0a80102 --local-dir models/ --local-dir-use-symlinks False
huggingface-cli download stabilityai/stable-diffusion-xl-1.0-tensorrt sdxl-1.0-refiner/clip2.opt/model.onnx sdxl-1.0-refiner/unetxl.opt/model.onnx sdxl-1.0-refiner/unetxl.opt/6ed855ee-2d70-11ee-af8e-0242c0a80101 sdxl-1.0-refiner/unetxl.opt/6e186582-2d74-11ee-8aa7-0242c0a80102 --local-dir models/ --local-dir-use-symlinks False
```

Convert CLIP2 model to expose "hidden_state" as output.

```bash
# clip2.opt
python clip_modifier.py -i models/sdxl-1.0-refiner/clip2.opt/model.onnx -o models/sdxl-1.0-refiner/clip2.opt.mod/model.onnx
```

Run the text-to-image script with the following example prompt and seed:

Set `pipeline-type` based on which version of models you have: `sdxl` for base, `sdxl-opt` for opt, `sdxl-turbo` for turbo

```bash
python txt2img.py --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --seed 42 --output refined_jungle_astro.jpg --refiner-onnx-model-path models/sdxl-1.0-refiner
python txt2img.py --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --seed 42 --output refined_jungle_astro.jpg --pipeline-type <model-version> --use-refiner
```

## Gradio application
@@ -83,8 +111,10 @@ pip install -r gradio_requirements.txt

Usage

Set `pipeline-type` based on which version of models you have: `sdxl` for base, `sdxl-opt` for opt, `sdxl-turbo` for turbo

```bash
python gradio_app.py -p "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
python gradio_app.py -p "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --pipeline-type <model-version>
```

This will load the models (which can take several minutes), and when the setup is ready, starts a server on `http://127.0.0.1:7860`.
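
The README commands above select a model variant with `--pipeline-type` and turn on the refiner with `--use-refiner`. The definition of these flags lives in `txt2img.py`, which this view does not show, so the following is only a minimal sketch of how such flags could be declared with `argparse`. The function name `build_parser`, the default value, and the help texts are assumptions made for illustration, not the script's actual code.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the README commands."""
    parser = argparse.ArgumentParser(description="txt2img flag sketch")
    parser.add_argument(
        "--pipeline-type",
        choices=["sdxl", "sdxl-opt", "sdxl-turbo"],
        default="sdxl",  # assumed default; the real script may differ
        help="sdxl for base, sdxl-opt for opt, sdxl-turbo for turbo",
    )
    parser.add_argument(
        "--use-refiner",
        action="store_true",
        help="run the refiner stage after the base pipeline",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.pipeline_type, args.use_refiner)
```

Restricting the flag with `choices` makes `argparse` reject an unknown variant before any model is loaded, which matters when loading takes several minutes.
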
110 changes: 71 additions & 39 deletions examples/diffusion/python_stable_diffusion_xl/gradio_app.py
@@ -24,58 +24,90 @@

 from txt2img import StableDiffusionMGX, get_args
 import gradio as gr
+import sys


+class PrintWrapper(object):
+    def __init__(self, org_handle):
+        self.org_handle = org_handle
+        self.log = ""
+
+        def wrapper_write(x):
+            self.log += x
+            return org_handle.write(x)
+
+        self.wrapper_write = wrapper_write
+
+    def __getattr__(self, attr):
+        return self.wrapper_write if attr == 'write' else getattr(
+            self.org_handle, attr)
+
+    def get_log(self):
+        return self.log
+
+
 def main():
     args = get_args()
     # Note: This will load the models, which can take several minutes
-    sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path,
+    sd = StableDiffusionMGX(args.pipeline_type, args.onnx_model_path,
+                            args.compiled_model_path, args.use_refiner,
                             args.refiner_onnx_model_path,
                             args.refiner_compiled_model_path, args.fp16,
                             args.force_compile, args.exhaustive_tune)
     sd.warmup(5)

-    def gr_wrapper(prompt, negative_prompt, steps, seed, scale,
+    def gr_wrapper(prompt, negative_prompt, steps, seed, scale, refiner_steps,
                    aesthetic_score, negative_aesthetic_score):
-        result = sd.run(
-            str(prompt),
-            str(negative_prompt),
-            int(steps),
-            int(seed),
-            float(scale),
-            float(aesthetic_score),
-            float(negative_aesthetic_score),
-        )
-        return StableDiffusionMGX.convert_to_rgb_image(result)
+        img = None
+        try:
+            oldStdout, oldStderr = sys.stdout, sys.stderr
+            sys.stdout, sys.stderr = PrintWrapper(sys.stdout), PrintWrapper(
+                sys.stderr)
+            result = sd.run(
+                str(prompt),
+                str(negative_prompt),
+                int(steps),
+                int(seed),
+                float(scale),
+                int(refiner_steps),
+                float(aesthetic_score),
+                float(negative_aesthetic_score),
+            )
+            img = StableDiffusionMGX.convert_to_rgb_image(result)
+        finally:
+            log = ''.join([sys.stdout.get_log(), sys.stderr.get_log()])
+            sys.stdout, sys.stderr = oldStdout, oldStderr
+        return img, log

-    use_refiner = bool(args.refiner_onnx_model_path
-                       or args.refiner_compiled_model_path)
-    demo = gr.Interface(
-        gr_wrapper,
-        [
-            gr.Textbox(value=args.prompt, label="Prompt"),
-            gr.Textbox(value=args.negative_prompt,
-                       label="Negative prompt (Optional)"),
-            gr.Slider(
-                1, 100, step=1, value=args.steps, label="Number of steps"),
-            gr.Textbox(value=args.seed, label="Random seed"),
-            gr.Slider(
-                1, 20, step=0.1, value=args.scale, label="Guidance scale"),
-            gr.Slider(1,
-                      20,
-                      step=0.1,
-                      value=args.refiner_aesthetic_score,
-                      label="Aesthetic score",
-                      visible=use_refiner),
-            gr.Slider(1,
-                      20,
-                      step=0.1,
-                      value=args.refiner_negative_aesthetic_score,
-                      label="Negative Aesthetic score",
-                      visible=use_refiner),
-        ],
+    demo = gr.Interface(gr_wrapper, [
+        gr.Textbox(value=args.prompt, label="Prompt"),
+        gr.Textbox(value=args.negative_prompt,
+                   label="Negative prompt (Optional)"),
+        gr.Slider(1, 100, step=1, value=args.steps, label="Number of steps"),
+        gr.Textbox(value=args.seed, label="Random seed"),
+        gr.Slider(1, 20, step=0.1, value=args.scale, label="Guidance scale"),
+        gr.Slider(0,
+                  100,
+                  step=1,
+                  value=args.refiner_steps,
+                  label="Number of refiner steps. (Use 0 to skip it)",
+                  visible=args.use_refiner),
+        gr.Slider(1,
+                  20,
+                  step=0.1,
+                  value=args.refiner_aesthetic_score,
+                  label="Aesthetic score (Refiner)",
+                  visible=args.use_refiner),
+        gr.Slider(1,
+                  20,
+                  step=0.1,
+                  value=args.refiner_negative_aesthetic_score,
+                  label="Negative Aesthetic score (Refiner)",
+                  visible=args.use_refiner),
+    ], [
         "image",
-    )
+        gr.Textbox(placeholder="Output log of the run", label="Output log")
+    ])
     demo.launch()
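
The change to `gr_wrapper` swaps `sys.stdout` and `sys.stderr` for wrappers that forward each write and keep a copy, then restores the originals in a `finally` block. The same pattern can be sketched in isolation as below. The names `TeeWriter` and `run_and_capture` are hypothetical and are not part of this commit; this is an illustrative sketch, not the code in `gradio_app.py`.

```python
import sys


class TeeWriter:
    """Hypothetical helper: forwards writes to a stream and keeps a copy."""

    def __init__(self, stream):
        self._stream = stream
        self._chunks = []

    def write(self, text):
        # Record the text, then pass it through so it still reaches the console.
        self._chunks.append(text)
        return self._stream.write(text)

    def __getattr__(self, name):
        # Delegate everything else (flush, isatty, ...) to the wrapped stream.
        return getattr(self._stream, name)

    def get_log(self):
        return "".join(self._chunks)


def run_and_capture(fn, *args, **kwargs):
    """Call fn while teeing stdout/stderr; return (result, captured_log)."""
    old_stdout, old_stderr = sys.stdout, sys.stderr
    out, err = TeeWriter(old_stdout), TeeWriter(old_stderr)
    sys.stdout, sys.stderr = out, err
    try:
        result = fn(*args, **kwargs)
    finally:
        # Always restore the real streams, even if fn raises.
        sys.stdout, sys.stderr = old_stdout, old_stderr
    return result, out.get_log() + err.get_log()
```

In this sketch the original streams are saved before the `try`, so the `finally` clause can restore them no matter where a failure occurs. The captured text is concatenated stdout first, then stderr, so the interleaving between the two streams is not preserved.
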

