Skip to content

Commit

Permalink
Add flags and rename vmfb's to different names
Browse files Browse the repository at this point in the history
  • Loading branch information
aviator19941 committed Dec 12, 2023
1 parent 5945477 commit a3c1a24
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 28 deletions.
10 changes: 3 additions & 7 deletions python/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import os
import sys
import re

from iree import runtime as ireert
import iree.compiler as ireec
Expand Down Expand Up @@ -98,8 +97,7 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
inst = CompiledClip(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(hf_model_name, "-clip")
if compile_to != "vmfb":
return module_str, tokenizer
else:
Expand All @@ -113,8 +111,7 @@ def run_clip_vmfb_comparison(args):
index = ireert.ParameterIndex()
index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-clip")
if args.vmfb_path:
mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists(f"{safe_name}.vmfb"):
Expand Down Expand Up @@ -194,8 +191,7 @@ def run_clip_vmfb_comparison(args):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-clip")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
35 changes: 26 additions & 9 deletions python/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import os
import sys
import re

from iree import runtime as ireert
from iree.compiler.ir import Context
Expand All @@ -30,6 +29,13 @@
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
Expand Down Expand Up @@ -76,6 +82,9 @@ def forward(self, sample, timestep, encoder_hidden_states):
def export_unet_model(
unet_model,
hf_model_name,
batch_size,
height,
width,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
Expand All @@ -93,6 +102,8 @@ def export_unet_model(
if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states_sizes = (2, 77, 1024)

sample = (batch_size, unet_model.unet.in_channels, height // 8, width // 8)

class CompiledUnet(CompiledModule):
if external_weights:
params = export_parameters(
Expand All @@ -103,7 +114,7 @@ class CompiledUnet(CompiledModule):

def main(
self,
sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32),
sample=AbstractTensor(*sample, dtype=torch.float32),
timestep=AbstractTensor(1, dtype=torch.float32),
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=torch.float32
Expand All @@ -115,8 +126,7 @@ def main(
inst = CompiledUnet(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(hf_model_name, "-unet")
if compile_to != "vmfb":
return module_str
else:
Expand All @@ -130,8 +140,7 @@ def run_unet_vmfb_comparison(unet_model, args):
index = ireert.ParameterIndex()
index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
if args.vmfb_path:
mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists(f"{safe_name}.vmfb"):
Expand All @@ -153,7 +162,13 @@ def run_unet_vmfb_comparison(unet_model, args):
vm_modules=vm_modules,
config=config,
)
sample = torch.rand(1, 4, 64, 64, dtype=torch.float32)
sample = torch.rand(
args.batch_size,
unet_model.unet.in_channels,
args.height // 8,
args.width // 8,
dtype=torch.float32,
)
timestep = torch.zeros(1, dtype=torch.float32)
if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
Expand Down Expand Up @@ -200,6 +215,9 @@ def run_unet_vmfb_comparison(unet_model, args):
mod_str = export_unet_model(
unet_model,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.hf_auth_token,
args.compile_to,
args.external_weights,
Expand All @@ -208,8 +226,7 @@ def run_unet_vmfb_comparison(unet_model, args):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
7 changes: 7 additions & 0 deletions python/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import iree.compiler as ireec
import numpy as np
import safetensors
import re


def save_external_weights(
Expand Down Expand Up @@ -81,3 +82,9 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
f.write(flatbuffer_blob)
print("Saved to", safe_name + ".vmfb")
exit()


def create_safe_name(hf_model_name, model_name_str):
safe_name = hf_model_name.split("/")[-1].strip() + model_name_str
safe_name = re.sub("-", "_", safe_name)
return safe_name
35 changes: 26 additions & 9 deletions python/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import os
import sys
import re

from iree import runtime as ireert
from iree.compiler.ir import Context
Expand All @@ -30,6 +29,13 @@
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
Expand Down Expand Up @@ -69,6 +75,9 @@ def forward(self, inp):
def export_vae_model(
vae_model,
hf_model_name,
batch_size,
height,
width,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
Expand All @@ -82,18 +91,19 @@ def export_vae_model(
mapper, vae_model, external_weights, external_weight_file
)

sample = (batch_size, 4, height // 8, width // 8)

class CompiledVae(CompiledModule):
params = export_parameters(vae_model)

def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)):
def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
return jittable(vae_model.forward)(inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledVae(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(hf_model_name, "-vae")
if compile_to != "vmfb":
return module_str
else:
Expand All @@ -107,8 +117,7 @@ def run_vae_vmfb_comparison(vae_model, args):
index = ireert.ParameterIndex()
index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
if args.vmfb_path:
mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists(f"{safe_name}.vmfb"):
Expand All @@ -130,7 +139,13 @@ def run_vae_vmfb_comparison(vae_model, args):
vm_modules=vm_modules,
config=config,
)
inp = torch.rand(1, 4, 64, 64, dtype=torch.float32)
inp = torch.rand(
args.batch_size,
4,
args.height // 8,
args.width // 8,
dtype=torch.float32,
)
device_inputs = [ireert.asdevicearray(config.device, inp)]

# Turbine output
Expand Down Expand Up @@ -165,6 +180,9 @@ def run_vae_vmfb_comparison(vae_model, args):
mod_str = export_vae_model(
vae_model,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.hf_auth_token,
args.compile_to,
args.external_weights,
Expand All @@ -173,8 +191,7 @@ def run_vae_vmfb_comparison(vae_model, args):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
15 changes: 12 additions & 3 deletions python/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
arguments = {
"hf_auth_token": None,
"hf_model_name": "CompVis/stable-diffusion-v1-4",
"batch_size": 1,
"height": 512,
"width": 512,
"run_vmfb": True,
"compile_to": None,
"external_weight_file": "",
Expand Down Expand Up @@ -55,14 +58,17 @@ def testExportClipModel(self):
namespace = argparse.Namespace(**arguments)
clip.run_clip_vmfb_comparison(namespace)
os.remove("stable_diffusion_v1_4_clip.safetensors")
os.remove("stable_diffusion_v1_4.vmfb")
os.remove("stable_diffusion_v1_4_clip.vmfb")

def testExportUnetModel(self):
with self.assertRaises(SystemExit) as cm:
unet.export_unet_model(
unet_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
arguments["batch_size"],
arguments["height"],
arguments["width"],
None,
"vmfb",
"safetensors",
Expand All @@ -74,14 +80,17 @@ def testExportUnetModel(self):
namespace = argparse.Namespace(**arguments)
unet.run_unet_vmfb_comparison(unet_model, namespace)
os.remove("stable_diffusion_v1_4_unet.safetensors")
os.remove("stable_diffusion_v1_4.vmfb")
os.remove("stable_diffusion_v1_4_unet.vmfb")

def testExportVaeModel(self):
with self.assertRaises(SystemExit) as cm:
vae.export_vae_model(
vae_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
arguments["batch_size"],
arguments["height"],
arguments["width"],
None,
"vmfb",
"safetensors",
Expand All @@ -93,7 +102,7 @@ def testExportVaeModel(self):
namespace = argparse.Namespace(**arguments)
vae.run_vae_vmfb_comparison(vae_model, namespace)
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4.vmfb")
os.remove("stable_diffusion_v1_4_vae.vmfb")


if __name__ == "__main__":
Expand Down

0 comments on commit a3c1a24

Please sign in to comment.