Commit

Merge pull request #13 from leptonai/yqdemo
feat(examples): add sdxl example
Yangqing authored Aug 16, 2023
2 parents 4c66646 + 245fd46 commit e66ea00
Showing 3 changed files with 268 additions and 0 deletions.
5 changes: 5 additions & 0 deletions advanced/sdxl/README.md
@@ -0,0 +1,5 @@
# Stable Diffusion XL model

This folder shows how to build a service around Stable Diffusion's most recent XL model.

More content to be added.
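In the meantime, a quick way to smoke-test the finished service: once it is running (e.g. via `python sdxl.py`), POST to the `/run` handler. A minimal sketch, assuming the photon listens on localhost port 8080 (adjust if your deployment differs):

import requests

# Request an image from the /run handler; the response body is the PNG bytes.
resp = requests.post(
    "http://localhost:8080/run",
    json={
        "prompt": "A majestic lion jumping from a big stone at night",
        "num_inference_steps": 40,
        "high_noise_frac": 0.8,
    },
)
resp.raise_for_status()
with open("lion.png", "wb") as f:
    f.write(resp.content)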
6 changes: 6 additions & 0 deletions advanced/sdxl/requirements.txt
@@ -0,0 +1,6 @@
diffusers>=0.19.0
gradio<4  # the web UI uses the Gradio 3.x API
invisible-watermark
leptonai
torch

257 changes: 257 additions & 0 deletions advanced/sdxl/sdxl.py
@@ -0,0 +1,257 @@
from io import BytesIO
from typing import Optional

from diffusers import DiffusionPipeline
import gradio as gr
import torch

from leptonai.photon import Photon, PNGResponse


class SDXL(Photon):
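    # Pip packages the photon needs installed when it is deployed.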
requirement_dependency = [
"gradio",
"torch",
"diffusers>=0.19.0",
"invisible-watermark",
]

    def init(self):
        cuda_available = torch.cuda.is_available()

        if cuda_available:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # The fp16 weight variant requires a GPU; fall back to the default
        # fp32 weights when running on CPU.
        self._dtype_kwargs = (
            {"torch_dtype": torch.float16, "variant": "fp16"}
            if cuda_available
            else {}
        )

        # load both base & refiner (the refiner is loaded lazily below)
        self.base = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            use_safetensors=True,
            **self._dtype_kwargs,
        )
if cuda_available:
self.base.to("cuda")
            # torch.compile is affected by the following issue. If you encounter
            # problems, comment out the torch.compile line.
# https://github.com/huggingface/diffusers/issues/4370
self.base.unet = torch.compile(
self.base.unet, mode="reduce-overhead", fullgraph=True
)

self._refiner = None

    @property
    def refiner(self):
        # Lazily load the refiner on first use, sharing the second text
        # encoder and the VAE with the base pipeline to save memory.
        if self._refiner is None:
            pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-xl-refiner-1.0",
                text_encoder_2=self.base.text_encoder_2,
                vae=self.base.vae,
                use_safetensors=True,
                **self._dtype_kwargs,
            )
            if torch.cuda.is_available():
                pipe.to("cuda")

                # torch.compile is affected by the following issue. If you encounter
                # problems, comment out the torch.compile line.
                # https://github.com/huggingface/diffusers/issues/4370
                pipe.unet = torch.compile(
                    pipe.unet, mode="reduce-overhead", fullgraph=True
                )
            self._refiner = pipe
        return self._refiner

@Photon.handler(
"run",
example={
"prompt": "A majestic lion jumping from a big stone at night",
"n_steps": 40,
"high_noise_frac": 0.8,
},
)
def run(
self,
prompt: str,
negative_prompt: Optional[str] = None,
width: Optional[int] = None,
height: Optional[int] = None,
guidance_scale: Optional[float] = 5.0,
seed: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
high_noise_frac: Optional[float] = 0.8,
use_refiner: Optional[bool] = True,
) -> PNGResponse:
images = self._run(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
samples=1,
seed=seed,
num_inference_steps=num_inference_steps,
high_noise_frac=high_noise_frac,
use_refiner=use_refiner,
)

img_io = BytesIO()
        images[0].save(img_io, format="PNG")
img_io.seek(0)
return PNGResponse(img_io)

def _run(
self,
prompt,
negative_prompt,
width,
height,
guidance_scale,
samples,
seed,
num_inference_steps,
high_noise_frac,
use_refiner,
):
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
else:
generator = None

if samples > 1:
prompt = [prompt] * samples
if negative_prompt is not None:
negative_prompt = [negative_prompt] * samples
generator = [generator] * samples

base_extra_kwargs = {}
if use_refiner:
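            # Ensemble-of-experts mode: the base model runs only the first
            # `high_noise_frac` of the denoising steps and outputs latents
            # for the refiner to finish (see denoising_start below).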
base_extra_kwargs["output_type"] = "latent"
base_extra_kwargs["denoising_end"] = high_noise_frac
# run both experts
images = self.base(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
generator=generator,
num_inference_steps=num_inference_steps,
**base_extra_kwargs,
).images
if use_refiner:
images = self.refiner(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
denoising_start=high_noise_frac,
image=images,
).images
return images

@Photon.handler(mount=True)
def ui(self):
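        # Note: this UI uses the Gradio 3.x API (gr.Box and .style() were
        # removed in Gradio 4), hence the gradio<4 pin in the requirements.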
blocks = gr.Blocks()

with blocks:
with gr.Group():
with gr.Box():
with gr.Column(scale=3):
with gr.Row():
prompt = gr.Textbox(
label="Enter your prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
).style(
border=(True, False, True, True),
rounded=(True, False, False, True),
container=False,
)
with gr.Row():
negative_prompt = gr.Textbox(
label="Enter your negative prompt",
show_label=False,
max_lines=1,
placeholder="Enter your negative prompt",
).style(
border=(True, False, True, True),
rounded=(True, False, False, True),
container=False,
)
with gr.Column(scale=1):
btn = gr.Button("Generate image").style(
margin=False,
rounded=(False, True, True, False),
)
gallery = gr.Gallery(
label="Generated images", show_label=False, elem_id="gallery"
).style(grid=[2], height="auto")

with gr.Row(elem_id="advanced-options-1"):
samples = gr.Slider(
label="Images", minimum=1, maximum=4, value=1, step=1
)
width = gr.Slider(
label="Width",
minimum=64,
maximum=1024,
value=512,
step=8,
)
height = gr.Slider(
label="Height",
minimum=64,
maximum=1024,
value=512,
step=8,
)
steps = gr.Slider(
label="Steps", minimum=1, maximum=50, value=25, step=1
)
with gr.Row(elem_id="advanced-options-2"):
scale = gr.Slider(
label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
)
high_noise_frac = gr.Slider(
label="Denoising fraction",
minimum=0,
maximum=1,
value=0.8,
step=0.1,
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=2147483647,
value=142857,
step=1,
)
use_refiner = gr.Checkbox(label="Use refiner", value=True)
btn.click(
self._run,
inputs=[
prompt,
negative_prompt,
width,
height,
scale,
samples,
seed,
steps,
high_noise_frac,
use_refiner,
],
outputs=gallery,
)

return blocks


if __name__ == "__main__":
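    # Running this file directly serves the photon over HTTP locally.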
p = SDXL()
p.launch()
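Alternatively, the handler can be invoked by name through the leptonai Python client. A hedged sketch, assuming the `Client` and `local` helpers from `leptonai.client`, that the photon serves on port 8080, and that the client returns the raw PNG bytes for a PNGResponse:

from leptonai.client import Client, local

# Connect to the locally running photon and call the "run" handler.
c = Client(local(port=8080))
img_bytes = c.run(prompt="A majestic lion jumping from a big stone at night")
with open("lion.png", "wb") as f:
    f.write(img_bytes)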
