diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py index 2697a1e00..aec606e3e 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py @@ -177,6 +177,48 @@ def is_valid_file(arg): help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", ) +p.add_argument( + "--clip_device", + default=None, + type=str, + help="Device to run CLIP on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--mmdit_device", + default=None, + type=str, + help="Device to run MMDiT on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--vae_device", + default=None, + type=str, + help="Device to run VAE on. If None, defaults to the device specified in args.device.", +) + +p.add_argument( + "--clip_target", + default=None, + type=str, + help="IREE target for CLIP compilation. If None, defaults to the target specified by --iree_target_triple.", +) + +p.add_argument( + "--mmdit_target", + default=None, + type=str, + help="IREE target for mmdit compilation. If None, defaults to the target specified by --iree_target_triple.", +) + +p.add_argument( + "--vae_target", + default=None, + type=str, + help="IREE target for vae compilation. If None, defaults to the target specified by --iree_target_triple.", +) + ############################################################################## # SD3 Modelling Options # These options are used to control model defining parameters for SD3. diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py index c16c22c3c..ed71a7f9a 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py @@ -25,27 +25,11 @@ import copy from datetime import datetime as dt -device_list = [ - "cpu", - "vulkan", - "cuda", - "rocm", -] - -rt_device_list = [ - "local-task", - "local-sync", - "vulkan", - "cuda", - "rocm", - "hip", -] - empty_pipe_dict = { - "vae": None, - "text_encoders": None, + "clip": None, "mmdit": None, "scheduler": None, + "vae": None, } EMPTY_FLAGS = { @@ -90,24 +74,40 @@ def __init__( self.batch_size = batch_size self.num_inference_steps = num_inference_steps self.devices = {} - if isinstance(self.device, dict): + if isinstance(device, dict): assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings." self.devices["clip"] = { "device": device["clip"], + "driver": utils.iree_device_map(device["clip"]), "target": iree_target_triple["clip"] } self.devices["mmdit"] = { "device": device["mmdit"], + "driver": utils.iree_device_map(device["mmdit"]), "target": iree_target_triple["mmdit"] } self.devices["vae"] = { "device": device["vae"], + "driver": utils.iree_device_map(device["vae"]), "target": iree_target_triple["vae"] } else: - self.devices["clip"] = device - self.devices["mmdit"] = device - self.devices["vae"] = device + assert isinstance(iree_target_triple, str), "Device and target triple must be both dicts or both strings." + self.devices["clip"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple + } + self.devices["mmdit"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple + } + self.devices["vae"] = { + "device": device, + "driver": utils.iree_device_map(device), + "target": iree_target_triple + } self.iree_target_triple = iree_target_triple self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS self.attn_spec = attn_spec @@ -176,6 +176,9 @@ def is_prepared(self, vmfbs, weights): val = None default_filepath = None continue + elif key == "clip": + val = "text_encoders" + default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb") else: val = vmfbs[key] default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb") @@ -197,7 +200,7 @@ def is_prepared(self, vmfbs, weights): default_name = os.path.join( self.external_weights_dir, w_key + "." + self.external_weights ) - if w_key == "text_encoders": + if w_key == "clip": default_name = os.path.join( self.external_weights_dir, f"sd3_clip_fp16.irpa" ) @@ -287,7 +290,7 @@ def export_submodel( if weights_only: input_mlir = { "vae": None, - "text_encoders": None, + "clip": None, "mmdit": None, "scheduler": None, } @@ -366,7 +369,7 @@ def export_submodel( ) del vae_torch return vae_vmfb, vae_external_weight_path - case "text_encoders": + case "clip": _, text_encoders_vmfb = sd3_text_encoders.export_text_encoders( self.hf_model_name, None, @@ -380,7 +383,7 @@ def export_submodel( self.ireec_flags["clip"], exit_on_vmfb=False, pipeline_dir=self.pipeline_dir, - input_mlir=input_mlir["text_encoders"], + input_mlir=input_mlir["clip"], attn_spec=self.attn_spec, output_batchsize=self.batch_size, ) @@ -392,7 +395,6 @@ def load_pipeline( self, vmfbs: dict, weights: dict, - rt_device: str | dict[str], compiled_pipeline: bool = False, split_scheduler: bool = True, extra_device_args: dict = {}, @@ -401,11 +403,12 @@ def load_pipeline( delegate = extra_device_args["npu_delegate_path"] else: delegate = None + self.runners = {} runners = {} load_start = time.time() runners["pipe"] = vmfbRunner( - rt_device, + self.devices["mmdit"]["driver"], vmfbs["mmdit"], weights["mmdit"], ) @@ -413,23 +416,24 @@ def load_pipeline( print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec") runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper( - rt_device, + self.devices["mmdit"]["driver"], vmfbs["scheduler"], ) sched_loaded = time.time() print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec") runners["vae"] = vmfbRunner( - rt_device, + self.devices["vae"]["driver"], vmfbs["vae"], - weights["vae"], + weights["vae"], + extra_plugin=delegate, ) vae_loaded = time.time() print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec") - runners["text_encoders"] = vmfbRunner( - rt_device, - vmfbs["text_encoders"], - weights["text_encoders"], + runners["clip"] = vmfbRunner( + self.devices["clip"]["driver"], + vmfbs["clip"], + weights["clip"], ) clip_loaded = time.time() print("\n[LOG] Text Encoders loaded in ", clip_loaded - vae_loaded, "sec") @@ -500,29 +504,29 @@ def generate_images( uncond_input_ids_list = list(uncond_input_ids_dict.values()) text_encoders_inputs = [ ireert.asdevicearray( - self.runners["text_encoders"].config.device, text_input_ids_list[0] + self.runners["clip"].config.device, text_input_ids_list[0] ), ireert.asdevicearray( - self.runners["text_encoders"].config.device, text_input_ids_list[1] + self.runners["clip"].config.device, text_input_ids_list[1] ), ireert.asdevicearray( - self.runners["text_encoders"].config.device, text_input_ids_list[2] + self.runners["clip"].config.device, text_input_ids_list[2] ), ireert.asdevicearray( - self.runners["text_encoders"].config.device, uncond_input_ids_list[0] + self.runners["clip"].config.device, uncond_input_ids_list[0] ), ireert.asdevicearray( - self.runners["text_encoders"].config.device, uncond_input_ids_list[1] + self.runners["clip"].config.device, uncond_input_ids_list[1] ), ireert.asdevicearray( - self.runners["text_encoders"].config.device, uncond_input_ids_list[2] + self.runners["clip"].config.device, uncond_input_ids_list[2] ), ] # Tokenize prompt and negative prompt. encode_prompts_start = time.time() prompt_embeds, pooled_prompt_embeds = self.runners[ - "text_encoders" + "clip" ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs) encode_prompts_end = time.time() @@ -690,6 +694,34 @@ def run_diffusers_cpu( mlirs = copy.deepcopy(map) vmfbs = copy.deepcopy(map) weights = copy.deepcopy(map) + + if any(x for x in [args.clip_device, args.mmdit_device, args.vae_device]): + assert all( + x for x in [args.clip_device, args.mmdit_device, args.vae_device] + ), "Please specify device for all submodels or pass --device for all submodels." + assert all( + x for x in [args.clip_target, args.mmdit_target, args.vae_target] + ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels." + args.device = "hybrid" + args.iree_target_triple = "_".join([args.clip_target, args.mmdit_target, args.vae_target]) + else: + args.clip_device = args.device + args.mmdit_device = args.device + args.vae_device = args.device + args.clip_target = args.iree_target_triple + args.mmdit_target = args.iree_target_triple + args.vae_target = args.iree_target_triple + + devices = { + "clip": args.clip_device, + "mmdit": args.mmdit_device, + "vae": args.vae_device, + } + targets = { + "clip": args.clip_target, + "mmdit": args.mmdit_target, + "vae": args.vae_target, + } ireec_flags = { "clip": args.ireec_flags + args.clip_flags, "mmdit": args.ireec_flags + args.unet_flags, @@ -705,6 +737,7 @@ def run_diffusers_cpu( str(args.max_length), args.precision, args.device, + args.iree_target_triple, ] if args.decomp_attn: pipe_id_list.append("decomp") @@ -730,8 +763,8 @@ def run_diffusers_cpu( args.max_length, args.batch_size, args.num_inference_steps, - args.device, - args.iree_target_triple, + devices, + targets, ireec_flags, args.attn_spec, args.decomp_attn, @@ -747,7 +780,7 @@ def run_diffusers_cpu( vmfbs.pop("scheduler") weights.pop("scheduler") sd3_pipe.load_pipeline( - vmfbs, weights, args.rt_device, args.compiled_pipeline, args.split_scheduler + vmfbs, weights, args.compiled_pipeline, args.split_scheduler ) sd3_pipe.generate_images( args.prompt, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 4489141d6..a862e0d39 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -12,6 +12,17 @@ # DPMSolverSDEScheduler, ) +_IREE_DEVICE_MAP = { + "cpu": "local-task", + "cpu-task": "local-task", + "cpu-sync": "local-sync", + "cuda": "cuda", + "vulkan": "vulkan", + "metal": "metal", + "rocm": "rocm", + "hip": "hip", + "intel-gpu": "level_zero", +} # If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument. MI_flags = { "all": [ @@ -81,6 +92,19 @@ ], } +def iree_device_map(device): + uri_parts = device.split("://", 2) + iree_driver = ( + _IREE_DEVICE_MAP[uri_parts[0]] + if uri_parts[0] in _IREE_DEVICE_MAP + else uri_parts[0] + ) + if len(uri_parts) == 1: + return iree_driver + elif "rocm" in uri_parts: + return "rocm" + else: + return f"{iree_driver}://{uri_parts[1]}" def compile_to_vmfb( module_str,