diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 03872dea..2b2b6062 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -70,7 +70,6 @@ jobs: pytest -v models/turbine_models/tests/sd_test.py pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 - pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2 - pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 \ No newline at end of file + pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5 diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index 5c02649a..3102ac3e 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -84,7 +84,12 @@ class PipelineComponent: """ def __init__( - self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False + self, + printer, + dest_type="devicearray", + dest_dtype="float16", + benchmark=False, + save_outputs=False, ): self.runner = None self.module_name = None @@ -92,6 +97,8 @@ def __init__( self.metadata = None self.printer = printer self.benchmark = benchmark + self.save_outputs = save_outputs + self.output_counter = 0 self.dest_type = dest_type self.dest_dtype = dest_dtype @@ -218,6 +225,16 @@ def _output_cast(self, output): case _: return output + def save_output(self, function_name, output): + if isinstance(output, tuple) or isinstance(output, list): + for i in output: + self.save_output(function_name, i) + else: + np.save( + f"{function_name}_output_{self.output_counter}.npy", output.to_host() + ) + self.output_counter += 1 + def _run(self, function_name, inputs: list): return self.module[function_name](*inputs) @@ -239,6 +256,8 @@ def __call__(self, function_name, inputs: list): output = self._run_and_benchmark(function_name, inputs) else: output = self._run(function_name, inputs) + if self.save_outputs: + self.save_output(function_name, output) output = self._output_cast(output) return output @@ -340,6 +359,7 @@ def __init__( hf_model_name: str | dict[str] = None, benchmark: bool | dict[bool] = False, verbose: bool = False, + save_outputs: bool | dict[bool] = False, common_export_args: dict = {}, ): self.map = model_map @@ -374,6 +394,7 @@ def __init__( "external_weights": external_weights, "hf_model_name": hf_model_name, "benchmark": benchmark, + "save_outputs": save_outputs, } for arg in map_arguments.keys(): self.map = merge_arg_into_map(self.map, map_arguments[arg], arg) @@ -391,7 +412,8 @@ def __init__( ) for submodel in self.map.keys(): for key, value in map_arguments.items(): - self.map = merge_export_arg(self.map, value, key) + if key not in ["benchmark", "save_outputs"]: + self.map = merge_export_arg(self.map, value, key) for key, value in self.map[submodel].get("export_args", {}).items(): if key == 
"hf_model_name": self.map[submodel]["keywords"].append( @@ -539,7 +561,11 @@ def is_prepared(self, vmfbs, weights): avail_files = os.listdir(self.external_weights_dir) candidates = [] for filename in avail_files: - if all(str(x) in filename for x in w_keywords): + if all( + str(x) in filename + or str(x) == os.path.join(self.external_weights_dir, filename) + for x in w_keywords + ): candidates.append( os.path.join(self.external_weights_dir, filename) ) @@ -723,7 +749,7 @@ def export_submodel( def load_map(self): for submodel in self.map.keys(): if not self.map[submodel]["load"]: - self.printer.print("Skipping load for ", submodel) + self.printer.print(f"Skipping load for {submodel}") continue self.load_submodel(submodel) @@ -739,6 +765,7 @@ def load_submodel(self, submodel): printer=self.printer, dest_type=dest_type, benchmark=self.map[submodel].get("benchmark", False), + save_outputs=self.map[submodel].get("save_outputs", False), ) self.map[submodel]["runner"].load( self.map[submodel]["driver"], @@ -751,6 +778,10 @@ def load_submodel(self, submodel): def unload_submodel(self, submodel): self.map[submodel]["runner"].unload() + self.map[submodel]["vmfb"] = None + self.map[submodel]["mlir"] = None + self.map[submodel]["weights"] = None + self.map[submodel]["export_args"]["input_mlir"] = None setattr(self, submodel, None) diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 5e025a4d..a852bf46 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -151,6 +151,12 @@ def is_valid_file(arg): help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.", ) +p.add_argument( + "--save_outputs", + type=str, + default=None, + help="A comma-separated list of submodel IDs for which to save output .npys for, or 'all' for all components.", +) ############################################################################## # SDXL Modelling Options # These options are used to control model defining parameters for SDXL. 
diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 277f74cb..a322cb08 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -190,6 +190,7 @@ def get_sd_model_map(hf_model_name): "stabilityai/sdxl-turbo", "stabilityai/stable-diffusion-xl-base-1.0", "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe", + "/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe", ]: return sdxl_model_map elif "stabilityai/stable-diffusion-3" in name: @@ -233,6 +234,12 @@ def __init__( benchmark: bool | dict[bool] = False, verbose: bool = False, batch_prompts: bool = False, + punet_quant_paths: dict[str] = None, + vae_weight_path: str = None, + vae_harness: bool = True, + add_tk_kernels: bool = False, + tk_kernels_dir: str | dict[str] = None, + save_outputs: bool | dict[bool] = False, ): common_export_args = { "hf_model_name": None, @@ -243,11 +250,11 @@ def __init__( "exit_on_vmfb": False, "pipeline_dir": pipeline_dir, "input_mlir": None, - "attn_spec": None, + "attn_spec": attn_spec, "external_weights": None, "external_weight_path": None, } - sd_model_map = get_sd_model_map(hf_model_name) + sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name)) for submodel in sd_model_map: if "load" not in sd_model_map[submodel]: sd_model_map[submodel]["load"] = True @@ -281,6 +288,7 @@ def __init__( hf_model_name, benchmark, verbose, + save_outputs, common_export_args, ) for submodel in sd_model_map: @@ -303,6 +311,7 @@ def __init__( self.cpu_scheduling = cpu_scheduling self.scheduler_id = scheduler_id self.num_inference_steps = num_inference_steps + self.punet_quant_paths = punet_quant_paths self.text_encoder = None self.unet = None @@ -311,6 +320,8 @@ def __init__( self.scheduler = None self.split_scheduler = True + self.add_tk_kernels = add_tk_kernels + self.tk_kernels_dir = tk_kernels_dir self.base_model_name = ( hf_model_name @@ -339,6 +350,9 @@ def __init__( self.scheduler_device = self.map["unet"]["device"] self.scheduler_driver = self.map["unet"]["driver"] self.scheduler_target = self.map["unet"]["target"] + if vae_weight_path is not None: + self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path + self.map["vae"]["export_args"]["vae_harness"] = vae_harness elif not self.is_sd3: self.tokenizer = CLIPTokenizer.from_pretrained( self.base_model_name, subfolder="tokenizer" @@ -351,23 +365,31 @@ def __init__( self.latents_dtype = torch_dtypes[self.latents_precision] self.use_i8_punet = self.use_punet = use_i8_punet + if self.use_punet: + self.setup_punet() + else: + self.map["unet"]["keywords"].append("!punet") + self.map["unet"]["function_name"] = "run_forward" + + def setup_punet(self): if self.use_i8_punet: + if self.add_tk_kernels: + self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels + self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir self.map["unet"]["export_args"]["precision"] = "i8" - self.map["unet"]["export_args"]["use_punet"] = True - self.map["unet"]["use_weights_for_export"] = True - self.map["unet"]["keywords"].append("punet") - self.map["unet"]["module_name"] = "compiled_punet" - self.map["unet"]["function_name"] = "main" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" ) + self.map["unet"]["export_args"]["quant_paths"] = 
self.punet_quant_paths for idx, word in enumerate(self.map["unet"]["keywords"]): if word in ["fp32", "fp16"]: self.map["unet"]["keywords"][idx] = "i8" break - else: - self.map["unet"]["keywords"].append("!punet") - self.map["unet"]["function_name"] = "run_forward" + self.map["unet"]["export_args"]["use_punet"] = True + self.map["unet"]["use_weights_for_export"] = True + self.map["unet"]["keywords"].append("punet") + self.map["unet"]["module_name"] = "compiled_punet" + self.map["unet"]["function_name"] = "main" # LOAD @@ -376,10 +398,6 @@ def load_scheduler( scheduler_id: str, steps: int = 30, ): - if self.is_sd3: - scheduler_device = self.mmdit.device - else: - scheduler_device = self.unet.device if not self.cpu_scheduling: self.map["scheduler"] = { "module_name": "compiled_scheduler", @@ -426,6 +444,10 @@ def load_scheduler( print("JIT export of scheduler failed. Loading CPU scheduler.") self.cpu_scheduling = True if self.cpu_scheduling: + if self.is_sd3: + scheduler_device = self.mmdit.device + else: + scheduler_device = self.unet.device scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id) self.scheduler = schedulers.SharkSchedulerCPUWrapper( scheduler, @@ -481,9 +503,12 @@ def prepare_latents( elif self.is_sdxl and self.cpu_scheduling: self.scheduler.do_guidance = False self.scheduler.repeat_sample = False - sample, add_time_ids, step_indexes, timesteps = ( - self.scheduler.initialize_sdxl(noise, num_inference_steps) - ) + ( + sample, + add_time_ids, + step_indexes, + timesteps, + ) = self.scheduler.initialize_sdxl(noise, num_inference_steps) return sample, add_time_ids, step_indexes, timesteps elif self.is_sdxl: return self.scheduler("run_initialize", noise) @@ -565,9 +590,11 @@ def _produce_latents_sdxl( [guidance_scale], dtype=self.map["unet"]["np_dtype"], ) + # Disable progress bar if we aren't in verbose mode or if we're printing + # benchmark latencies for unet. 
for i, t in tqdm( enumerate(timesteps), - disable=(self.map["unet"].get("benchmark") and self.verbose), + disable=(self.map["unet"].get("benchmark") or not self.verbose), ): if self.cpu_scheduling: latent_model_input, t = self.scheduler.scale_model_input( @@ -720,6 +747,15 @@ def numpy_to_pil_image(images): benchmark[i] = True else: benchmark = False + if args.save_outputs: + if args.save_outputs.lower() == "all": + save_outputs = True + else: + save_outputs = {} + for i in args.save_outputs.split(","): + save_outputs[i] = True + else: + save_outputs = False if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]): args.decomp_attn = { "text_encoder": args.decomp_attn, @@ -750,6 +786,7 @@ def numpy_to_pil_image(images): args.use_i8_punet, benchmark, args.verbose, + save_outputs=save_outputs, ) sd_pipe.prepare_all() sd_pipe.load_map() diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index cc8591b9..84d9cb3b 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -5,6 +5,7 @@ import safetensors import safetensors.numpy as safe_numpy import re +import glob from diffusers import ( PNDMScheduler, EulerDiscreteScheduler, @@ -17,34 +18,46 @@ "all": [ "--iree-global-opt-propagate-transposes=true", "--iree-opt-const-eval=false", - "--iree-opt-outer-dim-concat=true", - "--iree-vm-target-truncate-unsupported-floats", "--iree-llvmgpu-enable-prefetch=true", - "--iree-opt-data-tiling=false", - "--iree-codegen-gpu-native-math-precision=true", - "--iree-rocm-waves-per-eu=2", - "--iree-flow-inline-constants-max-byte-length=1", + "--iree-execution-model=async-external", ], "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))", ], + "punet": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], + "vae_preprocess": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics{pad-target-type=conv}))", ], "unet": [ "--iree-flow-enable-aggressive-fusion", - "--iree-flow-enable-fuse-horizontal-contractions=true", "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-opt-outer-dim-concat=true", + "--iree-opt-data-tiling=false", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-vm-target-truncate-unsupported-floats", ], "clip": [ "--iree-flow-enable-aggressive-fusion", "--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-rocm-waves-per-eu=2", + "--iree-codegen-llvmgpu-use-vector-distribution=true", ], "vae": [ "--iree-flow-enable-aggressive-fusion", + "--iree-flow-enable-fuse-horizontal-contractions", + "--iree-opt-aggressively-propagate-transposes=true", "--iree-codegen-llvmgpu-use-vector-distribution=true", + "--iree-opt-data-tiling=false", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-vm-target-truncate-unsupported-floats", ], "winograd": [""], } @@ -66,6 +79,9 @@ "pad_attention": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))", ], + "punet": [ + "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics), util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" + ], "preprocess_default": [ "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))", ], @@ -140,6 +156,69 @@ def iree_backend_map(device): return iree_device +def replace_with_tk_kernels(tk_kernels_dir, flow_dialect_ir, batch_size): + kernels = glob.glob(tk_kernels_dir + "/bs" + str(batch_size) + "/*") + + # Replace all calls to old kernel with new kernel + print("Inserting kernels and updating calls to kernels...") + kernel_name = {} + for kernel in kernels: + kernel_name[kernel] = kernel.split("/")[-1].split(".")[0] + kernel_map = {} + prefix_map = {} + + base = flow_dialect_ir.split("\n") + new_base = [] + for line in base: + for kernel in kernels: + suffix = kernel.split(".")[0].split("_")[-1] + if "bias" in suffix: + suffix = kernel.split(".")[0].split("_")[-2] + B, M, N, K = suffix.split("x") + old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" + if not old_kernel in line: + continue + if old_kernel in line and "func.func" in line: + num_args = line.count("arg") + with open(kernel, "r") as f: + data = f.readlines() + idx_with_kernel_args = [ + idx for idx, s in enumerate(data) if "func.func" in s + ][0] + kernel_args = data[idx_with_kernel_args].count("arg") + if num_args != kernel_args: + continue + kernel_map[kernel] = line.strip().split(" ")[1][1:-7] + prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1] + if ( + old_kernel in line + and "flow.dispatch" in line + and not "func.func" in line + ): + line = line.replace(kernel_map[kernel], kernel_name[kernel]) + line = line.replace(prefix_map[kernel], kernel_name[kernel]) + new_base.append(line) + # Insert kernels in appropriate locations + final_ir = [] + for line in new_base: + for kernel in kernels: + if ( + prefix_map[kernel] + " {" in line + and "flow.executable" in line + and "private" in line + ): + with open(kernel, "r") as f: + data = f.readlines() + translation_info = data[0].split("#translation = ")[1].strip() + extract = "".join(data[2:-2]) + extract = extract.replace("#translation", translation_info) + final_ir += extract + final_ir.append(line) + + print("tk kernels added") + return final_ir + + def compile_to_vmfb( module_str, device, @@ -153,9 +232,14 @@ def compile_to_vmfb( save_mlir=True, attn_spec=None, winograd=False, - masked_attention=False, + 
flagset_keywords=[], debug=False, + add_tk_kernels=False, + tk_kernels_dir=None, + batch_size=1, ): + if batch_size != 1 and batch_size != 8: + add_tk_kernels = False flags = [] if mlir_source == "file" and not isinstance(module_str, str): module_str = str(module_str) @@ -204,8 +288,6 @@ def compile_to_vmfb( "--iree-vm-bytecode-module-output-format=flatbuffer-binary", ] ) - if target_triple == "gfx942": - flags.extend(["--iree-rocm-waves-per-eu=2"]) elif device == "cuda": flags.extend( [ @@ -235,15 +317,21 @@ def compile_to_vmfb( elif "vae" in safe_name: flags.extend(MI_flags["vae"]) flags.extend(MI_flags["all"]) - if masked_attention: - flags.extend(GFX11_flags["pad_attention"]) + if "masked_attention" in flagset_keywords: + flags.extend(MI_flags["pad_attention"]) + elif "punet" in flagset_keywords: + flags.extend(MI_flags["punet"]) + elif "vae" in safe_name: + flags.extend(MI_flags["vae_preprocess"]) else: - flags.extend(GFX11_flags["preprocess_default"]) + flags.extend(MI_flags["preprocess_default"]) if "gfx11" in target_triple: flags.extend(GFX11_flags["all"]) - if masked_attention: + if "masked_attention" in flagset_keywords: flags.extend(GFX11_flags["pad_attention"]) + elif "punet" in flagset_keywords: + flags.extend(GFX11_flags["punet"]) else: flags.extend(GFX11_flags["preprocess_default"]) @@ -253,23 +341,22 @@ def compile_to_vmfb( # the TD spec is implemented in C++. if attn_spec in ["default", "mfma", "punet"]: - use_punet = True if attn_spec in ["punet", "i8"] else False - attn_spec = get_mfma_spec_path( - target_triple, - os.path.dirname(safe_name), - masked_attention, - use_punet=use_punet, - ) - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + use_punet = True if attn_spec in ["punet", "i8"] else False + attn_spec = get_mfma_spec_path( + target_triple, + os.path.dirname(safe_name), + use_punet=use_punet, + ) + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): - attn_spec = get_wmma_spec_path( - target_triple, os.path.dirname(safe_name), masked_attention - ) + attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name)) if attn_spec: flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec and attn_spec != "None": - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) for i, flag in enumerate(ireec_flags): k = flag.strip().split("=")[0] @@ -289,6 +376,34 @@ def compile_to_vmfb( for idx, flag in enumerate(flags): if flag is None: flags.pop(idx) + input_ir_type = "torch" + if add_tk_kernels: + print("Adding tk kernels") + flags.extend(["--compile-to=flow"]) + if mlir_source == "file": + flatbuffer_blob = ireec.compile_file( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + elif mlir_source == "str": + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + input_type=input_ir_type, + extra_args=flags, + ) + + flow_ir = flatbuffer_blob.decode("utf-8") + + flow_ir_tk = replace_with_tk_kernels(tk_kernels_dir, flow_ir, batch_size) + module_str = "\n".join(flow_ir_tk) + flags.pop() + flags.extend(["--compile-from=flow"]) + mlir_source = "str" + input_ir_type = "auto" + print("Compiling to", device, "with flags:", flags) 
# Forces a standard for naming files: @@ -305,7 +420,7 @@ def compile_to_vmfb( flatbuffer_blob = ireec.compile_file( module_str, target_backends=[device], - input_type="torch", + input_type=input_ir_type, extra_args=flags, ) elif mlir_source == "str": @@ -316,7 +431,7 @@ def compile_to_vmfb( flatbuffer_blob = ireec.compile_str( module_str, target_backends=[device], - input_type="torch", + input_type=input_ir_type, extra_args=flags, ) else: @@ -380,14 +495,20 @@ def save_external_weights( external_weights=None, external_weight_file=None, force_format=False, + vae_harness=False, ): if external_weights is not None: if external_weights in ["safetensors", "irpa"]: mod_params = dict(model.named_parameters()) mod_buffers = dict(model.named_buffers()) mod_params.update(mod_buffers) + vae_params = {} for name in mod_params: + if vae_harness: + vae_params[name.replace("vae.", "")] = mod_params[name] mapper["params." + name] = name + if vae_harness: + mod_params = vae_params if external_weight_file and not os.path.isfile(external_weight_file): if not force_format: safetensors.torch.save_file(mod_params, external_weight_file) diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 7ccd12c4..c18fb6da 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -98,6 +98,7 @@ def encode(self, inp): return latent +@torch.no_grad() def export_vae_model( hf_model_name, batch_size, @@ -118,6 +119,7 @@ def export_vae_model( input_mlir=None, weights_only=False, upload_ir=False, + vae_harness=False, ): dtype = torch.float16 if precision == "fp16" else torch.float32 np_dtype = "float16" if precision == "fp16" else "float32" @@ -161,18 +163,25 @@ def export_vae_model( if dtype == torch.float16: vae_model = vae_model.half() mapper = {} - utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path - ) + if (external_weight_path is not None) and ( + not os.path.exists(external_weight_path) + ): + utils.save_external_weights( + mapper, + vae_model, + external_weights, + external_weight_path, + vae_harness=vae_harness, + ) if weights_only: return external_weight_path - input_image_shape = (height, width, 3) + input_image_shape = (batch_size, 3, height, width) input_latents_shape = (batch_size, num_channels, height // 8, width // 8) encode_args = [ torch.empty( input_image_shape, - dtype=torch.float32, + dtype=dtype, ) ] decode_args = [ @@ -195,9 +204,12 @@ def export_vae_model( fxb = FxProgramsBuilder(vae_model) # TODO: fix issues with exporting the encode function. 
- # @fxb.export_program(args=(encode_args,)) - # def _encode(module, inputs,): - # return module.encode(*inputs) + @fxb.export_program(args=(encode_args,)) + def _encode( + module, + inputs, + ): + return module.encode(*inputs) @fxb.export_program(args=(decode_args,)) def _decode(module, inputs): @@ -205,6 +217,7 @@ def _decode(module, inputs): class CompiledVae(CompiledModule): decode = _decode + encode = _encode if external_weights: externalize_module_parameters(vae_model) @@ -228,6 +241,7 @@ class CompiledVae(CompiledModule): "output_dtypes": [np_dtype], } module = AddMetadataPass(module, model_metadata_decode, "decode").run() + module = AddMetadataPass(module, model_metadata_decode, "encode").run() if compile_to != "vmfb": return str(module) diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index 16602163..81a1735d 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -17,62 +17,30 @@ def run_vae_decode( return results -def run_torch_vae_decode(hf_model_name, variant, example_input): - from diffusers import AutoencoderKL +def run_vae_encode( + device, example_input, vmfb_path, hf_model_name, external_weight_path +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + inputs = [ireert.asdevicearray(runner.config.device, example_input)] - class VaeModel(torch.nn.Module): - def __init__( - self, - hf_model_name, - base_vae=False, - custom_vae="", - low_cpu_mem_usage=False, - hf_auth_token="", - ): - super().__init__() - self.vae = None - if custom_vae == "": - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - hf_auth_token=hf_auth_token, - ) - elif not isinstance(custom_vae, dict): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - hf_auth_token=hf_auth_token, - ) - else: - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - hf_auth_token=hf_auth_token, - ) - self.vae.load_state_dict(custom_vae) - self.base_vae = base_vae - - def decode_inp(self, input): - with torch.no_grad(): - input = 1 / 0.18215 * input - x = self.vae.decode(input, return_dict=False)[0] - return (x / 2 + 0.5).clamp(0, 1) - - def encode_inp(self, inp): - latents = self.vae.encode(inp).latent_dist.sample() - return 0.18215 * latents + results = runner.ctx.modules.compiled_vae["encode"](*inputs).to_host() + + return results + + +def run_torch_vae(hf_model_name, variant, example_input): + from diffusers import AutoencoderKL + from turbine_models.custom_models.sd_inference.vae import VaeModel vae_model = VaeModel( hf_model_name, ) if variant == "decode": - results = vae_model.decode_inp(example_input) + results = vae_model.decode(example_input) elif variant == "encode": - results = vae_model.encode_inp(example_input) + results = vae_model.encode(example_input) np_torch_output = results.detach().cpu().numpy() return np_torch_output diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py index 368fb0d7..017244f6 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py @@ -369,5 +369,18 @@ def is_valid_file(arg): help="extra iree-compile 
options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", ) +p.add_argument( + "--add_tk_kernels", + default=False, + action="store_true", + help="Flag to add compiled tk kernels.", +) + +p.add_argument( + "--tk_kernels_dir", + type=str, + default=None, + help="Path to directory containing tk kernels sorted by batch size.", +) args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py index bd36db76..70a1a043 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -82,7 +82,7 @@ def forward( return noise_pred -def get_punet_model(hf_model_name, external_weight_path, precision="i8"): +def get_punet_model(hf_model_name, external_weight_path, quant_paths, precision="i8"): from sharktank.models.punet.model import ( Unet2DConditionModel as sharktank_unet2d, ClassifierFreeGuidanceUnetModel as sharktank_CFGPunetModel, ) from sharktank.utils import cli if precision == "i8": - repo_id = "amd-shark/sdxl-quant-models" - subfolder = "unet/int8" - revision = "942e771bf0c2657a8b33380103d04747a75dfa4a" + repo_id = "amd-shark/sdxl-quant-int8" + subfolder = "mi300_all_sym_8_step14_fp32" + revision = "efda8afb35fd72c1769e02370b320b1011622958" elif precision in ["fp16", "fp32"]: repo_id = hf_model_name subfolder = "unet" - revision = "76d28af79639c28a79fa5c6c6468febd3490a37e" + revision = "defeb489fe2bb17b77d587924db9e58048a8c140" def download(filename): return hf_hub_download( repo_id=repo_id, subfolder=subfolder, filename=filename, revision=revision ) - results = { - "config.json": download("config.json"), - "params.safetensors": download("params.safetensors"), - } + if quant_paths and quant_paths.get("config") and os.path.exists(quant_paths["config"]): + results = { + "config.json": quant_paths["config"], + } + else: + results = { + "config.json": download("config.json"), + } + + if quant_paths and quant_paths.get("params") and os.path.exists(quant_paths["params"]): + results["params.safetensors"] = quant_paths["params"] + else: + results["params.safetensors"] = download("params.safetensors") + output_dir = os.path.dirname(external_weight_path) if precision == "i8": - results["quant_params.json"] = download("quant_params.json") + if ( + quant_paths + and quant_paths.get("quant_params") + and os.path.exists(quant_paths["quant_params"]) + ): + results["quant_params.json"] = quant_paths["quant_params"] + else: + results["quant_params.json"] = download("quant_params.json") ds_filename = os.path.basename(external_weight_path) output_path = os.path.join(output_dir, ds_filename) ds = get_punet_dataset( @@ -177,17 +194,21 @@ def export_unet_model( input_mlir=None, weights_only=False, use_punet=False, + quant_paths=None, + add_tk_kernels=False, + tk_kernels_dir=None, ): if use_punet: submodel_name = "punet" else: submodel_name = "unet" - if (not decomp_attn) and use_punet: - attn_spec = "punet" - elif (not decomp_attn) and "gfx9" in target: - attn_spec = "mfma" - elif (not decomp_attn) and "gfx11" in target: - attn_spec = "wmma" + if not attn_spec: + if (not decomp_attn) and use_punet: + attn_spec = "punet" + elif (not decomp_attn) and "gfx9" in target: + attn_spec = "mfma" + elif (not decomp_attn) and "gfx11" in target: + attn_spec = "wmma"
safe_name = utils.create_safe_name( hf_model_name, f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}", @@ -198,6 +219,10 @@ def export_unet_model( if decomp_attn == True: ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False" + # Currently, only int8 tk kernels are integrated + if add_tk_kernels and precision != "i8": + add_tk_kernels = False + if input_mlir: vmfb_path = utils.compile_to_vmfb( input_mlir, @@ -208,10 +233,15 @@ def export_unet_model( mlir_source="file", return_path=not exit_on_vmfb, attn_spec=attn_spec, + flagset_keywords=["punet"] if use_punet else [], + add_tk_kernels=add_tk_kernels, + tk_kernels_dir=tk_kernels_dir, ) return vmfb_path elif use_punet: - unet_model = get_punet_model(hf_model_name, external_weight_path, precision) + unet_model = get_punet_model( + hf_model_name, external_weight_path, quant_paths, precision + ) else: unet_model = UnetModel(hf_model_name, hf_auth_token, precision) @@ -340,6 +370,10 @@ class CompiledUnet(CompiledModule): safe_name, return_path=True, attn_spec=attn_spec, + flagset_keywords=["punet"] if use_punet else [], + add_tk_kernels=add_tk_kernels, + batch_size=batch_size, + tk_kernels_dir=tk_kernels_dir, ) if exit_on_vmfb: exit() @@ -378,6 +412,8 @@ class CompiledUnet(CompiledModule): args.decomp_attn, attn_spec=args.attn_spec, input_mlir=args.input_mlir, + add_tk_kernels=args.add_tk_kernels, + tk_kernels_dir=args.tk_kernels_dir ) if args.input_mlir: exit() diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py index 753cbb9e..8a02dc19 100644 --- a/models/turbine_models/custom_models/sdxl_inference/vae.py +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -119,9 +119,10 @@ def export_vae_model( mapper = {} - utils.save_external_weights( - mapper, vae_model, external_weights, external_weight_path - ) + if not os.path.exists(external_weight_path): + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_path + ) if weights_only: return external_weight_path diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index 674e7d81..98a3cfca 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -208,7 +208,7 @@ def testExportVaeModelDecode(self): current_args["hf_model_name"], current_args["external_weight_path"], ) - torch_output = vae_runner.run_torch_vae_decode( + torch_output = vae_runner.run_torch_vae( current_args["hf_model_name"], "decode", example_input, diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 216b6ff5..060e3a13 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -10,16 +10,12 @@ import shutil from transformers import CLIPTokenizer from turbine_models.custom_models.sd_inference.utils import create_safe_name -from turbine_models.custom_models.sd_inference import schedulers, vae +from turbine_models.custom_models.sd_inference import schedulers, vae, vae_runner from turbine_models.custom_models.sdxl_inference import ( sdxl_prompt_encoder, sdxl_prompt_encoder_runner, unet, unet_runner, - sdxl_scheduled_unet, - sdxl_scheduled_unet_runner, - vae_runner, - sdxl_compiled_pipeline, ) from turbine_models.utils.sdxl_benchmark import run_benchmark import unittest @@ -80,28 +76,46 @@ def command_line_args(request): @pytest.mark.usefixtures("command_line_args") class 
StableDiffusionXLTest(unittest.TestCase): def setUp(self): + from turbine_models.custom_models.sd_inference.sd_pipeline import ( + SharkSDPipeline, + ) + self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") + decomp_attn = { + "text_encoder": True, + "unet": False, + "vae": True, + } + self.pipe = SharkSDPipeline( + arguments["hf_model_name"], + arguments["height"], + arguments["width"], + arguments["batch_size"], + arguments["max_length"], + arguments["precision"], + arguments["device"], + arguments["iree_target_triple"], + ireec_flags=None, + attn_spec=arguments["attn_spec"], + decomp_attn=decomp_attn, + pipeline_dir="test_vmfbs", + external_weights_dir="test_weights", + external_weights=arguments["external_weights"], + num_inference_steps=arguments["num_inference_steps"], + cpu_scheduling=True, + scheduler_id=arguments["scheduler_id"], + shift=None, + use_i8_punet=False, + ) + self.pipe.prepare_all() - def test01_ExportPromptEncoder(self): + def test01_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( "Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda." ) - arguments["external_weight_path"] = ( - "prompt_encoder." + arguments["external_weights"] - ) - prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( - arguments["hf_model_name"], - hf_auth_token=None, - max_length=arguments["max_length"], - batch_size=arguments["batch_size"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights="safetensors", - external_weight_path=arguments["external_weight_path"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ) + arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] tokenizer_1 = CLIPTokenizer.from_pretrained( arguments["hf_model_name"], subfolder="tokenizer", @@ -126,8 +140,8 @@ def test01_ExportPromptEncoder(self): turbine_output1, turbine_output2, ) = sdxl_prompt_encoder_runner.run_prompt_encoder( - prompt_encoder_vmfb, - arguments["rt_device"], + arguments["vmfb_path"], + self.pipe.map["text_encoder"]["driver"], arguments["external_weight_path"], text_input_ids_list, uncond_input_ids_list, @@ -143,9 +157,9 @@ def test01_ExportPromptEncoder(self): if arguments["benchmark"] or arguments["tracy_profile"]: run_benchmark( "prompt_encoder", - prompt_encoder_vmfb, + arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["text_encoder"]["driver"], max_length=arguments["max_length"], tracy_profile=arguments["tracy_profile"], ) @@ -157,36 +171,10 @@ def test01_ExportPromptEncoder(self): def test02_ExportUnetModel(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") - unet_vmfb = unet.export_unet_model( - hf_model_name=arguments["hf_model_name"], - batch_size=arguments["batch_size"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - max_length=arguments["max_length"], - hf_auth_token=None, - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_unet." 
- + arguments["external_weights"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - decomp_attn=arguments["decomp_attn"], - attn_spec=arguments["attn_spec"], - exit_on_vmfb=False, - ) - arguments["external_weight_path"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_unet." - + arguments["external_weights"] - ) - arguments["vmfb_path"] = unet_vmfb + + arguments["vmfb_path"] = self.pipe.map["unet"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] + dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 sample = torch.rand( ( @@ -207,7 +195,7 @@ def test02_ExportUnetModel(self): guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) turbine = unet_runner.run_unet( - arguments["rt_device"], + self.pipe.map["unet"]["driver"], sample, timestep, prompt_embeds, @@ -235,7 +223,7 @@ def test02_ExportUnetModel(self): "unet", arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["unet"]["driver"], max_length=arguments["max_length"], height=arguments["height"], width=arguments["width"], @@ -252,34 +240,9 @@ def test02_ExportUnetModel(self): def test03_ExportVaeModelDecode(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") - vae_vmfb = vae.export_vae_model( - hf_model_name=arguments["hf_model_name"], - batch_size=arguments["batch_size"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_decode." - + arguments["external_weights"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - decomp_attn=True, - attn_spec=arguments["attn_spec"], - exit_on_vmfb=False, - ) - arguments["external_weight_path"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_decode." 
- + arguments["external_weights"] - ) - arguments["vmfb_path"] = vae_vmfb + + arguments["vmfb_path"] = self.pipe.map["vae"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["unet"]["weights"] example_input = torch.ones( arguments["batch_size"], 4, @@ -290,20 +253,15 @@ def test03_ExportVaeModelDecode(self): example_input_torch = example_input if arguments["precision"] == "fp16": example_input = example_input.half() - turbine = vae_runner.run_vae( - arguments["rt_device"], + turbine = vae_runner.run_vae_decode( + self.pipe.map["vae"]["driver"], example_input, arguments["vmfb_path"], arguments["hf_model_name"], - arguments["external_weight_path"], + self.pipe.map["vae"]["weights"], ) torch_output = vae_runner.run_torch_vae( arguments["hf_model_name"], - ( - "madebyollin/sdxl-vae-fp16-fix" - if arguments["precision"] == "fp16" - else "" - ), "decode", example_input_torch, ) @@ -312,7 +270,7 @@ def test03_ExportVaeModelDecode(self): "vae_decode", arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["vae"]["driver"], height=arguments["height"], width=arguments["width"], precision=arguments["precision"], @@ -323,40 +281,14 @@ def test03_ExportVaeModelDecode(self): np.testing.assert_allclose(torch_output, turbine, rtol, atol) + @pytest.mark.xfail(reason="NaN output on rocm, needs triage and file") def test04_ExportVaeModelEncode(self): if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]: self.skipTest( "Compilation error on cpu, vulkan and rocm; To be tested on cuda." ) - vae_vmfb = vae.export_vae_model( - vae_model=self.vae_model, - # This is a public model, so no auth required - hf_model_name=arguments["hf_model_name"], - batch_size=arguments["batch_size"], - height=arguments["height"], - width=arguments["width"], - precision=arguments["precision"], - compile_to="vmfb", - external_weights=arguments["external_weights"], - external_weight_path=self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_encode." - + arguments["external_weights"], - device=arguments["device"], - target=arguments["iree_target_triple"], - ireec_flags=arguments["ireec_flags"], - decomp_attn=True, - exit_on_vmfb=True, - ) - arguments["external_weight_path"] = ( - self.safe_model_name - + "_" - + arguments["precision"] - + "_vae_encode." 
- + arguments["external_weights"] - ) - arguments["vmfb_path"] = vae_vmfb + arguments["vmfb_path"] = self.pipe.map["vae"]["vmfb"] + arguments["external_weight_path"] = self.pipe.map["vae"]["weights"] example_input = torch.ones( arguments["batch_size"], 3, @@ -367,20 +299,15 @@ def test04_ExportVaeModelEncode(self): example_input_torch = example_input if arguments["precision"] == "fp16": example_input = example_input.half() - turbine = vae_runner.run_vae( - arguments["rt_device"], + turbine = vae_runner.run_vae_encode( + self.pipe.map["vae"]["driver"], example_input, arguments["vmfb_path"], arguments["hf_model_name"], - arguments["external_weight_path"], + self.pipe.map["vae"]["weights"], ) torch_output = vae_runner.run_torch_vae( arguments["hf_model_name"], - ( - "madebyollin/sdxl-vae-fp16-fix" - if arguments["precision"] == "fp16" - else "" - ), "encode", example_input_torch, ) @@ -388,8 +315,8 @@ def test04_ExportVaeModelEncode(self): run_benchmark( "vae_encode", arguments["vmfb_path"], - arguments["external_weight_path"], - arguments["rt_device"], + self.pipe.map["vae"]["weights"], + self.pipe.map["vae"]["driver"], height=arguments["height"], width=arguments["width"], precision=arguments["precision"], @@ -402,39 +329,9 @@ def test04_ExportVaeModelEncode(self): def test05_t2i_generate_images(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Have issues with submodels on vulkan, cuda") - from turbine_models.custom_models.sd_inference.sd_pipeline import ( - SharkSDPipeline, - ) - decomp_attn = { - "text_encoder": False, - "unet": False, - "vae": True, - } - sd_pipe = SharkSDPipeline( - arguments["hf_model_name"], - arguments["height"], - arguments["width"], - arguments["batch_size"], - arguments["max_length"], - arguments["precision"], - arguments["device"], - arguments["iree_target_triple"], - ireec_flags=None, # ireec_flags - attn_spec=arguments["attn_spec"], - decomp_attn=decomp_attn, - pipeline_dir="test_vmfbs", # pipeline_dir - external_weights_dir="test_weights", # external_weights_dir - external_weights=arguments["external_weights"], - num_inference_steps=arguments["num_inference_steps"], - cpu_scheduling=True, - scheduler_id=arguments["scheduler_id"], - shift=None, # shift - use_i8_punet=False, - ) - sd_pipe.prepare_all() - sd_pipe.load_map() - output = sd_pipe.generate_images( + self.pipe.load_map() + output = self.pipe.generate_images( arguments["prompt"], arguments["negative_prompt"], arguments["num_inference_steps"], @@ -447,45 +344,19 @@ def test05_t2i_generate_images(self): ) assert output is not None - @pytest.mark.skip(reason="Needs sdxl_quantized branch of IREE") + @pytest.mark.xfail(reason="compilation issue on gfx90a") def test06_t2i_generate_images_punet(self): if arguments["device"] in ["vulkan", "cuda", "rocm"]: self.skipTest( "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." 
) - from turbine_models.custom_models.sd_inference.sd_pipeline import ( - SharkSDPipeline, - ) - - decomp_attn = { - "text_encoder": False, - "unet": False, - "vae": True, - } - sd_pipe = SharkSDPipeline( - arguments["hf_model_name"], - arguments["height"], - arguments["width"], - arguments["batch_size"], - arguments["max_length"], - arguments["precision"], - arguments["device"], - arguments["iree_target_triple"], - ireec_flags=None, # ireec_flags - attn_spec=arguments["attn_spec"], - decomp_attn=decomp_attn, - pipeline_dir="test_vmfbs", # pipeline_dir - external_weights_dir="test_weights", # external_weights_dir - external_weights=arguments["external_weights"], - num_inference_steps=arguments["num_inference_steps"], - cpu_scheduling=True, - scheduler_id=arguments["scheduler_id"], - shift=None, # shift - use_i8_punet=True, - ) - sd_pipe.prepare_all() - sd_pipe.load_map() - output = sd_pipe.generate_images( + self.pipe.unload_submodel("unet") + self.pipe.use_punet = True + self.pipe.use_i8_punet = True + self.pipe.setup_punet() + self.pipe.prepare_all() + self.pipe.load_map() + output = self.pipe.generate_images( arguments["prompt"], arguments["negative_prompt"], arguments["num_inference_steps"],