Skip to content

Commit

Permalink
more batch sizes for SD2.1 (#3300)
Browse files Browse the repository at this point in the history
  • Loading branch information
richagadgil authored Aug 9, 2024
1 parent 408b6c3 commit 6b0f10f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 27 deletions.
6 changes: 3 additions & 3 deletions examples/diffusion/python_stable_diffusion_21/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ optimum-cli export onnx --model stabilityai/stable-diffusion-2-1 models/sd21-onn
```
*Note: `models/sd21-onnx` will be used in the scripts.*

Run the text-to-image script with the following example prompt and seed:
Run the text-to-image script with the following example prompt and seed (optionally, you can change the batch size / number of images generated for that prompt)

```bash
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg
python txt2img.py --prompt "a photograph of an astronaut riding a horse" --seed 13 --output astro_horse.jpg --batch 1
```
*Note: The first run will compile the models and cache them to make subsequent runs faster.*
*Note: The first run will compile the models and cache them to make subsequent runs faster. New batch sizes will result in the models re-compiling.*

The result should look like this:

Expand Down
4 changes: 2 additions & 2 deletions examples/diffusion/python_stable_diffusion_21/gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main():
args = get_args()
# Note: This will load the models, which can take several minutes
sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path,
args.fp16, args.force_compile,
args.fp16, args.batch, args.force_compile,
args.exhaustive_tune)
sd.warmup(5)

Expand All @@ -51,7 +51,7 @@ def gr_wrapper(prompt, negative_prompt, steps, seed, scale):
gr.Slider(
1, 20, step=0.1, value=args.scale, label="Guidance scale"),
],
"image",
gr.Gallery(),
)
demo.launch()

Expand Down
61 changes: 39 additions & 22 deletions examples/diffusion/python_stable_diffusion_21/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def get_args():
help="Number of steps",
)

parser.add_argument("-b",
"--batch",
type=int,
default=1,
help="Batch count or number of images to produce")

parser.add_argument(
"-p",
"--prompt",
Expand Down Expand Up @@ -198,7 +204,7 @@ def allocate_torch_tensors(model):


class StableDiffusionMGX():
def __init__(self, onnx_model_path, compiled_model_path, fp16,
def __init__(self, onnx_model_path, compiled_model_path, fp16, batch,
force_compile, exhaustive_tune):
model_id = "stabilityai/stable-diffusion-2-1"
print(f"Using {model_id}")
Expand All @@ -215,17 +221,20 @@ def __init__(self, onnx_model_path, compiled_model_path, fp16,
elif "all" in fp16:
fp16 = ["vae", "clip", "unet"]

self.batch = batch

print("Load models...")
self.models = {
"vae":
StableDiffusionMGX.load_mgx_model(
"vae_decoder", {"latent_sample": [1, 4, 64, 64]},
"vae_decoder", {"latent_sample": [self.batch, 4, 64, 64]},
onnx_model_path,
compiled_model_path=compiled_model_path,
use_fp16="vae" in fp16,
force_compile=force_compile,
exhaustive_tune=exhaustive_tune,
offload_copy=False),
offload_copy=False,
batch=self.batch),
"clip":
StableDiffusionMGX.load_mgx_model(
"text_encoder", {"input_ids": [2, 77]},
Expand All @@ -238,16 +247,17 @@ def __init__(self, onnx_model_path, compiled_model_path, fp16,
"unet":
StableDiffusionMGX.load_mgx_model(
"unet", {
"sample": [2, 4, 64, 64],
"encoder_hidden_states": [2, 77, 1024],
"sample": [2 * self.batch, 4, 64, 64],
"encoder_hidden_states": [2 * self.batch, 77, 1024],
"timestep": [1],
},
onnx_model_path,
compiled_model_path=compiled_model_path,
use_fp16="unet" in fp16,
force_compile=force_compile,
exhaustive_tune=exhaustive_tune,
offload_copy=False)
offload_copy=False,
batch=self.batch)
}

self.tensors = {
Expand Down Expand Up @@ -317,7 +327,7 @@ def run(self, prompt, negative_prompt, steps, seed, scale):
f"Creating random input data ({1}x{4}x{64}x{64}) (latents) with seed={seed}..."
)
latents = torch.randn(
(1, 4, 64, 64),
(self.batch, 4, 64, 64),
generator=torch.manual_seed(seed)).to(device="cuda")

print("Apply initial noise sigma\n")
Expand Down Expand Up @@ -369,12 +379,13 @@ def load_mgx_model(name,
use_fp16=False,
force_compile=False,
exhaustive_tune=False,
offload_copy=True):
offload_copy=True,
batch=1):
print(f"Loading {name} model...")
if compiled_model_path is None:
compiled_model_path = onnx_model_path
onnx_file = f"{onnx_model_path}/{name}/model.onnx"
mxr_file = f"{compiled_model_path}/{name}/model_{'fp16' if use_fp16 else 'fp32'}_{'gpu' if not offload_copy else 'oc'}.mxr"
mxr_file = f"{compiled_model_path}/{name}/model_{'fp16' if use_fp16 else 'fp32'}_b{batch}_{'gpu' if not offload_copy else 'oc'}.mxr"
if not force_compile and os.path.isfile(mxr_file):
print(f"Found mxr, loading it from {mxr_file}")
model = mgx.load(mxr_file, format="msgpack")
Expand Down Expand Up @@ -410,14 +421,16 @@ def get_embeddings(self, prompt_tokens):
copy_tensor_sync(self.tensors["clip"]["input_ids"],
prompt_tokens.input_ids.to(torch.int32))
run_model_sync(self.models["clip"], self.model_args["clip"])
return self.tensors["clip"][get_output_name(0)]
text_embeds = self.tensors["clip"][get_output_name(0)]
return torch.cat(
[torch.cat([i] * self.batch) for i in text_embeds.split(1)])

@staticmethod
def convert_to_rgb_image(image):
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
return Image.fromarray(images[0])
return [Image.fromarray(images[i]) for i in range(images.shape[0])]

@staticmethod
def save_image(pil_image, filename="output.png"):
Expand Down Expand Up @@ -458,14 +471,17 @@ def warmup(self, num_runs):
self.profile_start("warmup")
copy_tensor_sync(self.tensors["clip"]["input_ids"],
torch.ones((2, 77)).to(torch.int32))
copy_tensor_sync(self.tensors["unet"]["sample"],
torch.randn((2, 4, 64, 64)).to(torch.float32))
copy_tensor_sync(self.tensors["unet"]["encoder_hidden_states"],
torch.randn((2, 77, 1024)).to(torch.float32))
copy_tensor_sync(
self.tensors["unet"]["sample"],
torch.randn((2 * self.batch, 4, 64, 64)).to(torch.float32))
copy_tensor_sync(
self.tensors["unet"]["encoder_hidden_states"],
torch.randn((2 * self.batch, 77, 1024)).to(torch.float32))
copy_tensor_sync(self.tensors["unet"]["timestep"],
torch.atleast_1d(torch.randn(1).to(torch.int64)))
copy_tensor_sync(self.tensors["vae"]["latent_sample"],
torch.randn((1, 4, 64, 64)).to(torch.float32))
copy_tensor_sync(
self.tensors["vae"]["latent_sample"],
torch.randn((self.batch, 4, 64, 64)).to(torch.float32))

for _ in range(num_runs):
run_model_sync(self.models["clip"], self.model_args["clip"])
Expand All @@ -478,7 +494,7 @@ def warmup(self, num_runs):
args = get_args()

sd = StableDiffusionMGX(args.onnx_model_path, args.compiled_model_path,
args.fp16, args.force_compile,
args.fp16, args.batch, args.force_compile,
args.exhaustive_tune)
print("Warmup")
sd.warmup(5)
Expand All @@ -492,7 +508,8 @@ def warmup(self, num_runs):
sd.cleanup()

print("Convert result to rgb image...")
image = StableDiffusionMGX.convert_to_rgb_image(result)
filename = args.output if args.output else f"output_s{args.seed}_t{args.steps}.png"
StableDiffusionMGX.save_image(image, filename)
print(f"Image saved to {filename}")
images = StableDiffusionMGX.convert_to_rgb_image(result)
for i, image in enumerate(images):
filename = f"{args.batch}_{args.output}" if args.output else f"output_s{args.seed}_t{args.steps}_{i}.png"
StableDiffusionMGX.save_image(image, filename)
print(f"Image saved to {filename}")

0 comments on commit 6b0f10f

Please sign in to comment.