Add batch, height, width flags and rename vmfb's to different names #233

Merged · 1 commit · Dec 12, 2023
10 changes: 3 additions & 7 deletions python/turbine_models/custom_models/sd_inference/clip.py
@@ -6,7 +6,6 @@

import os
import sys
import re

from iree import runtime as ireert
import iree.compiler as ireec
@@ -98,8 +97,7 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
inst = CompiledClip(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(hf_model_name, "-clip")
if compile_to != "vmfb":
return module_str, tokenizer
else:
@@ -113,8 +111,7 @@ def run_clip_vmfb_comparison(args):
index = ireert.ParameterIndex()
index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-clip")
if args.vmfb_path:
mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists(f"{safe_name}.vmfb"):
@@ -194,8 +191,7 @@ def run_clip_vmfb_comparison(args):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-clip")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
35 changes: 26 additions & 9 deletions python/turbine_models/custom_models/sd_inference/unet.py
@@ -6,7 +6,6 @@

import os
import sys
import re

from iree import runtime as ireert
from iree.compiler.ir import Context
@@ -30,6 +29,13 @@
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
@@ -76,6 +82,9 @@ def forward(self, sample, timestep, encoder_hidden_states):
def export_unet_model(
unet_model,
hf_model_name,
batch_size,
height,
width,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
@@ -93,6 +102,8 @@ def export_unet_model(
if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states_sizes = (2, 77, 1024)

sample = (batch_size, unet_model.unet.in_channels, height // 8, width // 8)

class CompiledUnet(CompiledModule):
if external_weights:
params = export_parameters(
@@ -103,7 +114,7 @@ class CompiledUnet(CompiledModule):

def main(
self,
sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32),
sample=AbstractTensor(*sample, dtype=torch.float32),
timestep=AbstractTensor(1, dtype=torch.float32),
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=torch.float32
@@ -115,8 +126,7 @@ def main(
inst = CompiledUnet(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(hf_model_name, "-unet")
if compile_to != "vmfb":
return module_str
else:
@@ -130,8 +140,7 @@ def run_unet_vmfb_comparison(unet_model, args):
index = ireert.ParameterIndex()
index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
if args.vmfb_path:
mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists(f"{safe_name}.vmfb"):
@@ -153,7 +162,13 @@ def run_unet_vmfb_comparison(unet_model, args):
vm_modules=vm_modules,
config=config,
)
sample = torch.rand(1, 4, 64, 64, dtype=torch.float32)
sample = torch.rand(
args.batch_size,
unet_model.unet.in_channels,
args.height // 8,
args.width // 8,
dtype=torch.float32,
)
timestep = torch.zeros(1, dtype=torch.float32)
if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
@@ -200,6 +215,9 @@ def run_unet_vmfb_comparison(unet_model, args):
mod_str = export_unet_model(
unet_model,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.hf_auth_token,
args.compile_to,
args.external_weights,
@@ -208,8 +226,7 @@ def run_unet_vmfb_comparison(unet_model, args):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
7 changes: 7 additions & 0 deletions python/turbine_models/custom_models/sd_inference/utils.py
@@ -1,6 +1,7 @@
import iree.compiler as ireec
import numpy as np
import safetensors
import re


def save_external_weights(
@@ -81,3 +82,9 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
f.write(flatbuffer_blob)
print("Saved to", safe_name + ".vmfb")
exit()


def create_safe_name(hf_model_name, model_name_str):
safe_name = hf_model_name.split("/")[-1].strip() + model_name_str
safe_name = re.sub("-", "_", safe_name)
return safe_name
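
A quick usage sketch of the helper as defined above (the return values follow directly from the split/strip/substitute steps and match the renamed artifacts expected by sd_test.py below):

create_safe_name("CompVis/stable-diffusion-v1-4", "-unet")
# -> "stable_diffusion_v1_4_unet"
create_safe_name("stabilityai/stable-diffusion-2-1-base", "-vae")
# -> "stable_diffusion_2_1_base_vae"
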
35 changes: 26 additions & 9 deletions python/turbine_models/custom_models/sd_inference/vae.py
@@ -6,7 +6,6 @@

import os
import sys
import re

from iree import runtime as ireert
from iree.compiler.ir import Context
@@ -30,6 +29,13 @@
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
@@ -69,6 +75,9 @@ def forward(self, inp):
def export_vae_model(
vae_model,
hf_model_name,
batch_size,
height,
width,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
@@ -82,18 +91,19 @@ def export_vae_model(
mapper, vae_model, external_weights, external_weight_file
)

sample = (batch_size, 4, height // 8, width // 8)

class CompiledVae(CompiledModule):
params = export_parameters(vae_model)

def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)):
def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
return jittable(vae_model.forward)(inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledVae(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(hf_model_name, "-vae")
if compile_to != "vmfb":
return module_str
else:
@@ -107,8 +117,7 @@ def run_vae_vmfb_comparison(vae_model, args):
index = ireert.ParameterIndex()
index.load(args.external_weight_file)

safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
if args.vmfb_path:
mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
elif os.path.exists(f"{safe_name}.vmfb"):
@@ -130,7 +139,13 @@ def run_vae_vmfb_comparison(vae_model, args):
vm_modules=vm_modules,
config=config,
)
inp = torch.rand(1, 4, 64, 64, dtype=torch.float32)
inp = torch.rand(
args.batch_size,
4,
args.height // 8,
args.width // 8,
dtype=torch.float32,
)
device_inputs = [ireert.asdevicearray(config.device, inp)]

# Turbine output
@@ -165,6 +180,9 @@ def run_vae_vmfb_comparison(vae_model, args):
mod_str = export_vae_model(
vae_model,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.hf_auth_token,
args.compile_to,
args.external_weights,
@@ -173,8 +191,7 @@ def run_vae_vmfb_comparison(vae_model, args):
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = args.hf_model_name.split("/")[-1].strip()
safe_name = re.sub("-", "_", safe_name)
safe_name = utils.create_safe_name(args.hf_model_name, "-vae")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
15 changes: 12 additions & 3 deletions python/turbine_models/tests/sd_test.py
@@ -14,6 +14,9 @@
arguments = {
"hf_auth_token": None,
"hf_model_name": "CompVis/stable-diffusion-v1-4",
"batch_size": 1,
"height": 512,
"width": 512,
"run_vmfb": True,
"compile_to": None,
"external_weight_file": "",
@@ -55,14 +58,17 @@ def testExportClipModel(self):
namespace = argparse.Namespace(**arguments)
clip.run_clip_vmfb_comparison(namespace)
os.remove("stable_diffusion_v1_4_clip.safetensors")
os.remove("stable_diffusion_v1_4.vmfb")
os.remove("stable_diffusion_v1_4_clip.vmfb")

def testExportUnetModel(self):
with self.assertRaises(SystemExit) as cm:
unet.export_unet_model(
unet_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
arguments["batch_size"],
arguments["height"],
arguments["width"],
None,
"vmfb",
"safetensors",
@@ -74,14 +80,17 @@ def testExportUnetModel(self):
namespace = argparse.Namespace(**arguments)
unet.run_unet_vmfb_comparison(unet_model, namespace)
os.remove("stable_diffusion_v1_4_unet.safetensors")
os.remove("stable_diffusion_v1_4.vmfb")
os.remove("stable_diffusion_v1_4_unet.vmfb")

def testExportVaeModel(self):
with self.assertRaises(SystemExit) as cm:
vae.export_vae_model(
vae_model,
# This is a public model, so no auth required
"CompVis/stable-diffusion-v1-4",
arguments["batch_size"],
arguments["height"],
arguments["width"],
None,
"vmfb",
"safetensors",
@@ -93,7 +102,7 @@ def testExportVaeModel(self):
namespace = argparse.Namespace(**arguments)
vae.run_vae_vmfb_comparison(vae_model, namespace)
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4.vmfb")
os.remove("stable_diffusion_v1_4_vae.vmfb")


if __name__ == "__main__":