Skip to content

Commit

Permalink
Fixes for multi-device (SD3)
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 17, 2024
1 parent b793686 commit 7754609
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 45 deletions.
42 changes: 42 additions & 0 deletions models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,48 @@ def is_valid_file(arg):
help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.",
)

# Per-submodel device overrides. Each option falls back to the global
# --device value when left unset (resolved later in the pipeline driver).
for _flag, _label in (("clip", "CLIP"), ("mmdit", "MMDiT"), ("vae", "VAE")):
    p.add_argument(
        f"--{_flag}_device",
        default=None,
        type=str,
        help=f"Device to run {_label} on. If None, defaults to the device specified in args.device.",
    )

# Per-submodel IREE compile-target overrides. Each option falls back to the
# global --iree_target_triple value when left unset.
for _flag, _label in (("clip", "CLIP"), ("mmdit", "mmdit"), ("vae", "vae")):
    p.add_argument(
        f"--{_flag}_target",
        default=None,
        type=str,
        help=f"IREE target for {_label} compilation. If None, defaults to the target specified by --iree_target_triple.",
    )

##############################################################################
# SD3 Modelling Options
# These options are used to control model defining parameters for SD3.
Expand Down
123 changes: 78 additions & 45 deletions models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,11 @@
import copy
from datetime import datetime as dt

device_list = [
"cpu",
"vulkan",
"cuda",
"rocm",
]

rt_device_list = [
"local-task",
"local-sync",
"vulkan",
"cuda",
"rocm",
"hip",
]

# Template mapping of pipeline submodel names to their (initially absent)
# artifacts; callers deep-copy this to track mlirs / vmfbs / weights per
# submodel. Keys must match the submodel names used by export_submodel and
# load_pipeline ("clip", not the older "text_encoders").
# Fix: the previous literal carried a stale "text_encoders" entry and a
# duplicate "vae" key left over from the text_encoders -> clip rename.
empty_pipe_dict = {
    "clip": None,
    "mmdit": None,
    "scheduler": None,
    "vae": None,
}

EMPTY_FLAGS = {
Expand Down Expand Up @@ -90,24 +74,40 @@ def __init__(
self.batch_size = batch_size
self.num_inference_steps = num_inference_steps
self.devices = {}
if isinstance(self.device, dict):
if isinstance(device, dict):
assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings."
self.devices["clip"] = {
"device": device["clip"],
"driver": utils.iree_device_map(device["clip"]),
"target": iree_target_triple["clip"]
}
self.devices["mmdit"] = {
"device": device["mmdit"],
"driver": utils.iree_device_map(device["mmdit"]),
"target": iree_target_triple["mmdit"]
}
self.devices["vae"] = {
"device": device["vae"],
"driver": utils.iree_device_map(device["vae"]),
"target": iree_target_triple["vae"]
}
else:
self.devices["clip"] = device
self.devices["mmdit"] = device
self.devices["vae"] = device
assert isinstance(iree_target_triple, str), "Device and target triple must be both dicts or both strings."
self.devices["clip"] = {
"device": device,
"driver": utils.iree_device_map(device),
"target": iree_target_triple
}
self.devices["mmdit"] = {
"device": device,
"driver": utils.iree_device_map(device),
"target": iree_target_triple
}
self.devices["vae"] = {
"device": device,
"driver": utils.iree_device_map(device),
"target": iree_target_triple
}
self.iree_target_triple = iree_target_triple
self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
self.attn_spec = attn_spec
Expand Down Expand Up @@ -176,6 +176,9 @@ def is_prepared(self, vmfbs, weights):
val = None
default_filepath = None
continue
elif key == "clip":
val = "text_encoders"
default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb")
else:
val = vmfbs[key]
default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb")
Expand All @@ -197,7 +200,7 @@ def is_prepared(self, vmfbs, weights):
default_name = os.path.join(
self.external_weights_dir, w_key + "." + self.external_weights
)
if w_key == "text_encoders":
if w_key == "clip":
default_name = os.path.join(
self.external_weights_dir, f"sd3_clip_fp16.irpa"
)
Expand Down Expand Up @@ -287,7 +290,7 @@ def export_submodel(
if weights_only:
input_mlir = {
"vae": None,
"text_encoders": None,
"clip": None,
"mmdit": None,
"scheduler": None,
}
Expand Down Expand Up @@ -366,7 +369,7 @@ def export_submodel(
)
del vae_torch
return vae_vmfb, vae_external_weight_path
case "text_encoders":
case "clip":
_, text_encoders_vmfb = sd3_text_encoders.export_text_encoders(
self.hf_model_name,
None,
Expand All @@ -380,7 +383,7 @@ def export_submodel(
self.ireec_flags["clip"],
exit_on_vmfb=False,
pipeline_dir=self.pipeline_dir,
input_mlir=input_mlir["text_encoders"],
input_mlir=input_mlir["clip"],
attn_spec=self.attn_spec,
output_batchsize=self.batch_size,
)
Expand All @@ -392,7 +395,6 @@ def load_pipeline(
self,
vmfbs: dict,
weights: dict,
rt_device: str | dict[str],
compiled_pipeline: bool = False,
split_scheduler: bool = True,
extra_device_args: dict = {},
Expand All @@ -401,35 +403,37 @@ def load_pipeline(
delegate = extra_device_args["npu_delegate_path"]
else:
delegate = None

self.runners = {}
runners = {}
load_start = time.time()
runners["pipe"] = vmfbRunner(
rt_device,
self.devices["mmdit"]["driver"],
vmfbs["mmdit"],
weights["mmdit"],
)
unet_loaded = time.time()
print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
rt_device,
self.devices["mmdit"]["driver"],
vmfbs["scheduler"],
)

sched_loaded = time.time()
print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
runners["vae"] = vmfbRunner(
rt_device,
self.devices["vae"]["driver"],
vmfbs["vae"],
weights["vae"],
weights["vae"],
extra_plugin=delegate,
)
vae_loaded = time.time()
print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec")
runners["text_encoders"] = vmfbRunner(
rt_device,
vmfbs["text_encoders"],
weights["text_encoders"],
runners["clip"] = vmfbRunner(
self.devices["clip"]["driver"],
vmfbs["clip"],
weights["clip"],
)
clip_loaded = time.time()
print("\n[LOG] Text Encoders loaded in ", clip_loaded - vae_loaded, "sec")
Expand Down Expand Up @@ -500,29 +504,29 @@ def generate_images(
uncond_input_ids_list = list(uncond_input_ids_dict.values())
text_encoders_inputs = [
ireert.asdevicearray(
self.runners["text_encoders"].config.device, text_input_ids_list[0]
self.runners["clip"].config.device, text_input_ids_list[0]
),
ireert.asdevicearray(
self.runners["text_encoders"].config.device, text_input_ids_list[1]
self.runners["clip"].config.device, text_input_ids_list[1]
),
ireert.asdevicearray(
self.runners["text_encoders"].config.device, text_input_ids_list[2]
self.runners["clip"].config.device, text_input_ids_list[2]
),
ireert.asdevicearray(
self.runners["text_encoders"].config.device, uncond_input_ids_list[0]
self.runners["clip"].config.device, uncond_input_ids_list[0]
),
ireert.asdevicearray(
self.runners["text_encoders"].config.device, uncond_input_ids_list[1]
self.runners["clip"].config.device, uncond_input_ids_list[1]
),
ireert.asdevicearray(
self.runners["text_encoders"].config.device, uncond_input_ids_list[2]
self.runners["clip"].config.device, uncond_input_ids_list[2]
),
]

# Tokenize prompt and negative prompt.
encode_prompts_start = time.time()
prompt_embeds, pooled_prompt_embeds = self.runners[
"text_encoders"
"clip"
].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
encode_prompts_end = time.time()

Expand Down Expand Up @@ -690,6 +694,34 @@ def run_diffusers_cpu(
mlirs = copy.deepcopy(map)
vmfbs = copy.deepcopy(map)
weights = copy.deepcopy(map)

if any(x for x in [args.clip_device, args.mmdit_device, args.vae_device]):
assert all(
x for x in [args.clip_device, args.mmdit_device, args.vae_device]
), "Please specify device for all submodels or pass --device for all submodels."
assert all(
x for x in [args.clip_target, args.mmdit_target, args.vae_target]
), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels."
args.device = "hybrid"
args.iree_target_triple = "_".join([args.clip_target, args.mmdit_target, args.vae_target])
else:
args.clip_device = args.device
args.mmdit_device = args.device
args.vae_device = args.device
args.clip_target = args.iree_target_triple
args.mmdit_target = args.iree_target_triple
args.vae_target = args.iree_target_triple

devices = {
"clip": args.clip_device,
"mmdit": args.mmdit_device,
"vae": args.vae_device,
}
targets = {
"clip": args.clip_target,
"mmdit": args.mmdit_target,
"vae": args.vae_target,
}
ireec_flags = {
"clip": args.ireec_flags + args.clip_flags,
"mmdit": args.ireec_flags + args.unet_flags,
Expand All @@ -705,6 +737,7 @@ def run_diffusers_cpu(
str(args.max_length),
args.precision,
args.device,
args.iree_target_triple,
]
if args.decomp_attn:
pipe_id_list.append("decomp")
Expand All @@ -730,8 +763,8 @@ def run_diffusers_cpu(
args.max_length,
args.batch_size,
args.num_inference_steps,
args.device,
args.iree_target_triple,
devices,
targets,
ireec_flags,
args.attn_spec,
args.decomp_attn,
Expand All @@ -747,7 +780,7 @@ def run_diffusers_cpu(
vmfbs.pop("scheduler")
weights.pop("scheduler")
sd3_pipe.load_pipeline(
vmfbs, weights, args.rt_device, args.compiled_pipeline, args.split_scheduler
vmfbs, weights, args.compiled_pipeline, args.split_scheduler
)
sd3_pipe.generate_images(
args.prompt,
Expand Down
24 changes: 24 additions & 0 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
# DPMSolverSDEScheduler,
)

_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "metal",
"rocm": "rocm",
"hip": "hip",
"intel-gpu": "level_zero",
}
# If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument.
MI_flags = {
"all": [
Expand Down Expand Up @@ -81,6 +92,19 @@
],
}

def iree_device_map(device):
    """Translate a device alias, optionally suffixed with '://<path>', to an
    IREE runtime driver URI.

    Unknown aliases are passed through unchanged; a 'rocm' component causes
    any path suffix to be dropped (the rocm driver takes no URI path here).
    """
    parts = device.split("://", 2)
    # Fall back to the alias itself when it has no entry in the map.
    driver = _IREE_DEVICE_MAP.get(parts[0], parts[0])
    if len(parts) == 1:
        return driver
    if "rocm" in parts:
        return "rocm"
    return f"{driver}://{parts[1]}"

def compile_to_vmfb(
module_str,
Expand Down

0 comments on commit 7754609

Please sign in to comment.