diff --git a/python/turbine_models/custom_models/sd_inference/clip.py b/python/turbine_models/custom_models/sd_inference/clip.py index ec2dcb3fb..4b640617f 100644 --- a/python/turbine_models/custom_models/sd_inference/clip.py +++ b/python/turbine_models/custom_models/sd_inference/clip.py @@ -6,7 +6,6 @@ import os import sys -import re from iree import runtime as ireert import iree.compiler as ireec @@ -98,8 +97,7 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)): inst = CompiledClip(context=Context(), import_to=import_to) module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(hf_model_name, "-clip") if compile_to != "vmfb": return module_str, tokenizer else: @@ -113,8 +111,7 @@ def run_clip_vmfb_comparison(args): index = ireert.ParameterIndex() index.load(args.external_weight_file) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(args.hf_model_name, "-clip") if args.vmfb_path: mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) elif os.path.exists(f"{safe_name}.vmfb"): @@ -194,8 +191,7 @@ def run_clip_vmfb_comparison(args): args.iree_target_triple, args.vulkan_max_allocation, ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(args.hf_model_name, "-clip") with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/unet.py b/python/turbine_models/custom_models/sd_inference/unet.py index 4c4d3c227..3372e3e05 100644 --- a/python/turbine_models/custom_models/sd_inference/unet.py +++ b/python/turbine_models/custom_models/sd_inference/unet.py @@ -6,7 +6,6 @@ import os import sys -import re from iree import runtime as ireert from iree.compiler.ir import Context @@ -30,6 +29,13 @@ help="HF model name", default="CompVis/stable-diffusion-v1-4", ) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") parser.add_argument("--run_vmfb", action="store_true") parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") parser.add_argument("--external_weight_file", type=str, default="") @@ -76,6 +82,9 @@ def forward(self, sample, timestep, encoder_hidden_states): def export_unet_model( unet_model, hf_model_name, + batch_size, + height, + width, hf_auth_token=None, compile_to="torch", external_weights=None, @@ -93,6 +102,8 @@ def export_unet_model( if hf_model_name == "stabilityai/stable-diffusion-2-1-base": encoder_hidden_states_sizes = (2, 77, 1024) + sample = (batch_size, unet_model.unet.in_channels, height // 8, width // 8) + class CompiledUnet(CompiledModule): if external_weights: params = export_parameters( @@ -103,7 +114,7 @@ class CompiledUnet(CompiledModule): def main( self, - sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32), + sample=AbstractTensor(*sample, dtype=torch.float32), timestep=AbstractTensor(1, dtype=torch.float32), encoder_hidden_states=AbstractTensor( *encoder_hidden_states_sizes, dtype=torch.float32 @@ -115,8 +126,7 @@ def main( inst = CompiledUnet(context=Context(), import_to=import_to) module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(hf_model_name, "-unet") if compile_to != "vmfb": return module_str else: @@ -130,8 +140,7 @@ def run_unet_vmfb_comparison(unet_model, args): index = ireert.ParameterIndex() index.load(args.external_weight_file) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(args.hf_model_name, "-unet") if args.vmfb_path: mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) elif os.path.exists(f"{safe_name}.vmfb"): @@ -153,7 +162,13 @@ def run_unet_vmfb_comparison(unet_model, args): vm_modules=vm_modules, config=config, ) - sample = torch.rand(1, 4, 64, 64, dtype=torch.float32) + sample = torch.rand( + args.batch_size, + unet_model.unet.in_channels, + args.height // 8, + args.width // 8, + dtype=torch.float32, + ) timestep = torch.zeros(1, dtype=torch.float32) if args.hf_model_name == "CompVis/stable-diffusion-v1-4": encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) @@ -200,6 +215,9 @@ def run_unet_vmfb_comparison(unet_model, args): mod_str = export_unet_model( unet_model, args.hf_model_name, + args.batch_size, + args.height, + args.width, args.hf_auth_token, args.compile_to, args.external_weights, @@ -208,8 +226,7 @@ def run_unet_vmfb_comparison(unet_model, args): args.iree_target_triple, args.vulkan_max_allocation, ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(args.hf_model_name, "-unet") with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/utils.py b/python/turbine_models/custom_models/sd_inference/utils.py index 459d09aa8..0cba5338e 100644 --- a/python/turbine_models/custom_models/sd_inference/utils.py +++ b/python/turbine_models/custom_models/sd_inference/utils.py @@ -1,6 +1,7 @@ import iree.compiler as ireec import numpy as np import safetensors +import re def save_external_weights( @@ -81,3 +82,9 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): f.write(flatbuffer_blob) print("Saved to", safe_name + ".vmfb") exit() + + +def create_safe_name(hf_model_name, model_name_str): + safe_name = hf_model_name.split("/")[-1].strip() + model_name_str + safe_name = re.sub("-", "_", safe_name) + return safe_name diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py index cf3c587a1..b86d88ca5 100644 --- a/python/turbine_models/custom_models/sd_inference/vae.py +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -6,7 +6,6 @@ import os import sys -import re from iree import runtime as ireert from iree.compiler.ir import Context @@ -30,6 +29,13 @@ help="HF model name", default="CompVis/stable-diffusion-v1-4", ) +parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for inference" +) +parser.add_argument( + "--height", type=int, default=512, help="Height of Stable Diffusion" +) +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") parser.add_argument("--run_vmfb", action="store_true") parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") parser.add_argument("--external_weight_file", type=str, default="") @@ -69,6 +75,9 @@ def forward(self, inp): def export_vae_model( vae_model, hf_model_name, + batch_size, + height, + width, hf_auth_token=None, compile_to="torch", external_weights=None, @@ -82,18 +91,19 @@ def export_vae_model( mapper, vae_model, external_weights, external_weight_file ) + sample = (batch_size, 4, height // 8, width // 8) + class CompiledVae(CompiledModule): params = export_parameters(vae_model) - def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)): + def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): return jittable(vae_model.forward)(inp) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledVae(context=Context(), import_to=import_to) module_str = str(CompiledModule.get_mlir_module(inst)) - safe_name = hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(hf_model_name, "-vae") if compile_to != "vmfb": return module_str else: @@ -107,8 +117,7 @@ def run_vae_vmfb_comparison(vae_model, args): index = ireert.ParameterIndex() index.load(args.external_weight_file) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(args.hf_model_name, "-vae") if args.vmfb_path: mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) elif os.path.exists(f"{safe_name}.vmfb"): @@ -130,7 +139,13 @@ def run_vae_vmfb_comparison(vae_model, args): vm_modules=vm_modules, config=config, ) - inp = torch.rand(1, 4, 64, 64, dtype=torch.float32) + inp = torch.rand( + args.batch_size, + 4, + args.height // 8, + args.width // 8, + dtype=torch.float32, + ) device_inputs = [ireert.asdevicearray(config.device, inp)] # Turbine output @@ -165,6 +180,9 @@ def run_vae_vmfb_comparison(vae_model, args): mod_str = export_vae_model( vae_model, args.hf_model_name, + args.batch_size, + args.height, + args.width, args.hf_auth_token, args.compile_to, args.external_weights, @@ -173,8 +191,7 @@ def run_vae_vmfb_comparison(vae_model, args): args.iree_target_triple, args.vulkan_max_allocation, ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) + safe_name = utils.create_safe_name(args.hf_model_name, "-vae") with open(f"{safe_name}.mlir", "w+") as f: f.write(mod_str) print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py index c3834840d..e01027fc5 100644 --- a/python/turbine_models/tests/sd_test.py +++ b/python/turbine_models/tests/sd_test.py @@ -14,6 +14,9 @@ arguments = { "hf_auth_token": None, "hf_model_name": "CompVis/stable-diffusion-v1-4", + "batch_size": 1, + "height": 512, + "width": 512, "run_vmfb": True, "compile_to": None, "external_weight_file": "", @@ -55,7 +58,7 @@ def testExportClipModel(self): namespace = argparse.Namespace(**arguments) clip.run_clip_vmfb_comparison(namespace) os.remove("stable_diffusion_v1_4_clip.safetensors") - os.remove("stable_diffusion_v1_4.vmfb") + os.remove("stable_diffusion_v1_4_clip.vmfb") def testExportUnetModel(self): with self.assertRaises(SystemExit) as cm: @@ -63,6 +66,9 @@ def testExportUnetModel(self): unet_model, # This is a public model, so no auth required "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], None, "vmfb", "safetensors", @@ -74,7 +80,7 @@ def testExportUnetModel(self): namespace = argparse.Namespace(**arguments) unet.run_unet_vmfb_comparison(unet_model, namespace) os.remove("stable_diffusion_v1_4_unet.safetensors") - os.remove("stable_diffusion_v1_4.vmfb") + os.remove("stable_diffusion_v1_4_unet.vmfb") def testExportVaeModel(self): with self.assertRaises(SystemExit) as cm: @@ -82,6 +88,9 @@ def testExportVaeModel(self): vae_model, # This is a public model, so no auth required "CompVis/stable-diffusion-v1-4", + arguments["batch_size"], + arguments["height"], + arguments["width"], None, "vmfb", "safetensors", @@ -93,7 +102,7 @@ def testExportVaeModel(self): namespace = argparse.Namespace(**arguments) vae.run_vae_vmfb_comparison(vae_model, namespace) os.remove("stable_diffusion_v1_4_vae.safetensors") - os.remove("stable_diffusion_v1_4.vmfb") + os.remove("stable_diffusion_v1_4_vae.vmfb") if __name__ == "__main__":