diff --git a/python/turbine_models/custom_models/sd_inference/controlnet.py b/python/turbine_models/custom_models/sd_inference/controlnet.py new file mode 100644 index 000000000..487c514c1 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/controlnet.py @@ -0,0 +1,178 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import ControlNetModel as CNetModel + +import safetensors +import argparse +import re + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="lllyasviel/control_v11p_sd15_canny", +) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + + +class ControlNetModel(torch.nn.Module): + def __init__( + self, model_id="lllyasviel/control_v11p_sd15_canny", low_cpu_mem_usage=False + ): + super().__init__() + self.cnet = CNetModel.from_pretrained( + model_id, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.in_channels = self.cnet.config.in_channels + self.train(False) + + def forward( + self, + latent, + timestep, + text_embedding, + stencil_image_input, + ): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. + # TODO: guidance NOT NEEDED change in `get_input_info` later + latents = torch.cat([latent] * 2) # needs to be same as controlledUNET latents + stencil_image = torch.cat( + [stencil_image_input] * 2 + ) # needs to be same as controlledUNET latents + ( + down_block_res_samples, + mid_block_res_sample, + ) = self.cnet.forward( + latents, + timestep, + encoder_hidden_states=text_embedding, + controlnet_cond=stencil_image, + return_dict=False, + ) + return tuple(list(down_block_res_samples) + [mid_block_res_sample]) + + +def export_controlnet_model( + controlnet_model, + hf_model_name, + batch_size, + height, + width, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + max_alloc=None, +): + mapper = {} + utils.save_external_weights( + mapper, controlnet_model, external_weights, external_weight_path + ) + + class CompiledControlnet(CompiledModule): + if external_weights: + params = export_parameters( + controlnet_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(controlnet_model) + + def main( + self, + latent=AbstractTensor(1, 4, 512, 512, dtype=torch.float32), + timestep=AbstractTensor(1, dtype=torch.float32), + text_embedding=AbstractTensor(2, 72, 768, dtype=torch.float32), + stencil_image_input=AbstractTensor(1, 3, 4096, 4096, dtype=torch.float32), + ): + return jittable(controlnet_model.forward)( + latent, + timestep, + text_embedding, + stencil_image_input, + ) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledControlnet(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +if __name__ == "__main__": + args = parser.parse_args() + controlnet_model = ControlNetModel( + args.hf_model_name, + ) + mod_str = export_controlnet_model( + controlnet_model, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + + if mod_str is None: + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/unet.py b/python/turbine_models/custom_models/sd_inference/unet.py index 272c7af7f..f3c1e57c8 100644 --- a/python/turbine_models/custom_models/sd_inference/unet.py +++ b/python/turbine_models/custom_models/sd_inference/unet.py @@ -53,10 +53,23 @@ help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") +parser.add_argument( + "--controlled", + dest="controlled", + action="store_true", + help="Whether or not to use controlled unet (for use with controlnet)", +) +parser.add_argument( + "--no-controlled", + dest="controlled", + action="store_false", + help="Whether or not to use controlled unet (for use with controlnet)", +) +parser.set_defaults(controlled=False) class UnetModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token): + def __init__(self, hf_model_name, hf_auth_token, is_controlled): super().__init__() self.unet = UNet2DConditionModel.from_pretrained( hf_model_name, @@ -64,8 +77,12 @@ def __init__(self, hf_model_name, hf_auth_token): token=hf_auth_token, ) self.guidance_scale = 7.5 + if is_controlled: + self.forward = self.forward_controlled + else: + self.forward = self.forward_default - def forward(self, sample, timestep, encoder_hidden_states): + def forward_default(self, sample, timestep, encoder_hidden_states): samples = torch.cat([sample] * 2) unet_out = self.unet.forward( samples, timestep, encoder_hidden_states, return_dict=False @@ -76,6 +93,70 @@ def forward(self, sample, timestep, encoder_hidden_states): ) return noise_pred + def forward_controlled( + self, + sample, + timestep, + encoder_hidden_states, + control1, + control2, + control3, + control4, + control5, + control6, + control7, + control8, + control9, + control10, + control11, + control12, + control13, + scale1, + scale2, + scale3, + scale4, + scale5, + scale6, + scale7, + scale8, + scale9, + scale10, + scale11, + scale12, + scale13, + ): + db_res_samples = tuple( + [ + control1 * scale1, + control2 * scale2, + control3 * scale3, + control4 * scale4, + control5 * scale5, + control6 * scale6, + control7 * scale7, + control8 * scale8, + control9 * scale9, + control10 * scale10, + control11 * scale11, + control12 * scale12, + ] + ) + mb_res_samples = control13 * scale13 + samples = torch.cat([sample] * 2) + unet_out = self.unet.forward( + samples, + timestep, + encoder_hidden_states, + down_block_additional_residuals=db_res_samples, + mid_block_additional_residual=mb_res_samples, + return_dict=False, + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + def export_unet_model( unet_model, @@ -90,6 +171,7 @@ def export_unet_model( device=None, target_triple=None, max_alloc=None, + is_controlled=False, ): mapper = {} utils.save_external_weights( @@ -100,7 +182,7 @@ def export_unet_model( if hf_model_name == "stabilityai/stable-diffusion-2-1-base": encoder_hidden_states_sizes = (2, 77, 1024) - sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8) + sample = (batch_size, unet_model.unet.config.in_channels, height, width) class CompiledUnet(CompiledModule): if external_weights: @@ -120,8 +202,105 @@ def main( ): return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states) + class CompiledControlledUnet(CompiledModule): + if external_weights: + params = export_parameters( + unet_model, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(unet_model) + + def main( + self, + sample=AbstractTensor(*sample, dtype=torch.float32), + timestep=AbstractTensor(1, dtype=torch.float32), + encoder_hidden_states=AbstractTensor( + *encoder_hidden_states_sizes, dtype=torch.float32 + ), + control1=AbstractTensor(2, 320, height, width, dtype=torch.float32), + control2=AbstractTensor(2, 320, height, width, dtype=torch.float32), + control3=AbstractTensor(2, 320, height, width, dtype=torch.float32), + control4=AbstractTensor( + 2, 320, height // 2, width // 2, dtype=torch.float32 + ), + control5=AbstractTensor( + 2, 640, height // 2, width // 2, dtype=torch.float32 + ), + control6=AbstractTensor( + 2, 640, height // 2, width // 2, dtype=torch.float32 + ), + control7=AbstractTensor( + 2, 640, height // 4, width // 4, dtype=torch.float32 + ), + control8=AbstractTensor( + 2, 1280, height // 4, width // 4, dtype=torch.float32 + ), + control9=AbstractTensor( + 2, 1280, height // 4, width // 4, dtype=torch.float32 + ), + control10=AbstractTensor( + 2, 1280, height // 8, width // 8, dtype=torch.float32 + ), + control11=AbstractTensor( + 2, 1280, height // 8, width // 8, dtype=torch.float32 + ), + control12=AbstractTensor( + 2, 1280, height // 8, width // 8, dtype=torch.float32 + ), + control13=AbstractTensor( + 2, 1280, height // 8, width // 8, dtype=torch.float32 + ), + scale1=AbstractTensor(1, dtype=torch.float32), + scale2=AbstractTensor(1, dtype=torch.float32), + scale3=AbstractTensor(1, dtype=torch.float32), + scale4=AbstractTensor(1, dtype=torch.float32), + scale5=AbstractTensor(1, dtype=torch.float32), + scale6=AbstractTensor(1, dtype=torch.float32), + scale7=AbstractTensor(1, dtype=torch.float32), + scale8=AbstractTensor(1, dtype=torch.float32), + scale9=AbstractTensor(1, dtype=torch.float32), + scale10=AbstractTensor(1, dtype=torch.float32), + scale11=AbstractTensor(1, dtype=torch.float32), + scale12=AbstractTensor(1, dtype=torch.float32), + scale13=AbstractTensor(1, dtype=torch.float32), + ): + return jittable(unet_model.forward)( + sample, + timestep, + encoder_hidden_states, + control1, + control2, + control3, + control4, + control5, + control6, + control7, + control8, + control9, + control10, + control11, + control12, + control13, + scale1, + scale2, + scale3, + scale4, + scale5, + scale6, + scale7, + scale8, + scale9, + scale10, + scale11, + scale12, + scale13, + ) + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = CompiledUnet(context=Context(), import_to=import_to) + if is_controlled: + inst = CompiledControlledUnet(context=Context(), import_to=import_to) + else: + inst = CompiledUnet(context=Context(), import_to=import_to) module_str = str(CompiledModule.get_mlir_module(inst)) safe_name = utils.create_safe_name(hf_model_name, "-unet") @@ -134,8 +313,9 @@ def main( if __name__ == "__main__": args = parser.parse_args() unet_model = UnetModel( - args.hf_model_name, + args.hf_model_name if not args.controlled else "CompVis/stable-diffusion-v1-4", args.hf_auth_token, + args.controlled, ) mod_str = export_unet_model( unet_model, @@ -150,6 +330,7 @@ def main( args.device, args.iree_target_triple, args.vulkan_max_allocation, + args.controlled, ) safe_name = utils.create_safe_name(args.hf_model_name, "-unet") with open(f"{safe_name}.mlir", "w+") as f: