From 6ca109a9c96e5a038decf3fb7ebfe994572d5e78 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Tue, 10 Sep 2024 23:15:14 -0500
Subject: [PATCH] Add attention decomposition mechanism to sdxl clip exports.

---
 .../custom_models/sdxl_inference/clip.py      | 50 ++++++++-----
 .../sdxl_inference/sdxl_prompt_encoder.py     | 74 +++++++++++--------
 2 files changed, 74 insertions(+), 50 deletions(-)

diff --git a/models/turbine_models/custom_models/sdxl_inference/clip.py b/models/turbine_models/custom_models/sdxl_inference/clip.py
index 2740745ed..269e87d57 100644
--- a/models/turbine_models/custom_models/sdxl_inference/clip.py
+++ b/models/turbine_models/custom_models/sdxl_inference/clip.py
@@ -62,6 +62,7 @@ def export_clip_model(
     input_mlir=None,
     attn_spec=None,
     weights_only=False,
+    decomp_attn=True,
 ):
     if pipeline_dir not in [None, ""]:
         safe_name = os.path.join(pipeline_dir, "clip_" + str(index))
@@ -118,25 +119,36 @@ def export_clip_model(
 
     if weights_only:
         return weights_path
-
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                text_encoder_model,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(text_encoder_model)
-
-        def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)):
-            return jittable(text_encoder_model.forward)(inp)
-
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledClip(context=Context(), import_to=import_to)
-
-    module_str = str(CompiledModule.get_mlir_module(inst))
+    decomp_list = []
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
+
+            def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)):
+                return jittable(text_encoder_model.forward)(inp)
+
+        import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
+        inst = CompiledClip(context=Context(), import_to=import_to)
+
+        module_str = str(CompiledModule.get_mlir_module(inst))
 
     if compile_to != "vmfb":
         return module_str, tokenizer
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
index d547cadf7..3b9fb102f 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py
@@ -171,7 +171,7 @@ def export_prompt_encoder(
     attn_spec=None,
     weights_only=False,
     batch_input=False,
-    decomp_attn=False,  # Compatibility
+    decomp_attn=True,
 ):
     do_classifier_free_guidance = True
 
@@ -233,39 +233,51 @@
     if weights_only:
         return None, external_weight_path
 
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                prompt_encoder_module,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(prompt_encoder_module)
-
-        def encode_prompts(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward)(
-                t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
-            )
+    decomp_list = []
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    prompt_encoder_module,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(prompt_encoder_module)
+
+            def encode_prompts(
+                self,
+                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            ):
+                return jittable(prompt_encoder_module.forward)(
+                    t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
+                )
 
-        def encode_prompts_turbo(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
+            def encode_prompts_turbo(
+                self,
+                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            ):
+                return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
 
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledClip(context=Context(), import_to=import_to)
+        import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
+        inst = CompiledClip(context=Context(), import_to=import_to)
 
-    module = CompiledModule.get_mlir_module(inst)
+        module = CompiledModule.get_mlir_module(inst)
 
     model_metadata_encode = {
         "model_name": hf_model_name + "_text_encoder",
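Note on the pattern: both hunks wrap the CompiledModule definition and export in
the same decomposition scope, and decomp_attn now defaults to True, so the
exported MLIR contains no fused scaled-dot-product-attention ops unless the
caller opts out. The sketch below distills the pattern into a self-contained
toy export. It is illustrative only, assuming the shark_turbine.aot export API
these files already use (CompiledModule, AbstractTensor, jittable,
decompositions) and iree.compiler.ir.Context; ToyAttention, export_toy, and the
tensor shapes are hypothetical names invented for this sketch.

    import torch

    # Import paths assumed from this repo's conventions; the same API may be
    # exposed under the iree.turbine namespace instead of shark_turbine.
    from iree.compiler.ir import Context
    from shark_turbine.aot import CompiledModule, AbstractTensor, jittable
    from shark_turbine.aot import decompositions


    class ToyAttention(torch.nn.Module):
        # Hypothetical module whose forward hits the fused SDPA op.
        def forward(self, q, k, v):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)


    def export_toy(decomp_attn=True, compile_to="linalg"):
        model = ToyAttention()
        decomp_list = []
        if decomp_attn:
            # The same op set the patch registers for decomposition.
            decomp_list = [
                torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
                torch.ops.aten._scaled_dot_product_flash_attention.default,
                torch.ops.aten.scaled_dot_product_attention,
            ]
        # Extend the active AOT decomposition table for the duration of the
        # export, so tracing lowers attention into elementary aten ops.
        with decompositions.extend_aot_decompositions(
            from_current=True,
            add_ops=decomp_list,
        ):

            class CompiledToy(CompiledModule):
                def main(
                    self,
                    q=AbstractTensor(1, 8, 64, 64, dtype=torch.float32),
                    k=AbstractTensor(1, 8, 64, 64, dtype=torch.float32),
                    v=AbstractTensor(1, 8, 64, 64, dtype=torch.float32),
                ):
                    return jittable(model.forward)(q, k, v)

            import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
            inst = CompiledToy(context=Context(), import_to=import_to)
            module_str = str(CompiledModule.get_mlir_module(inst))
        return module_str

Keeping decomp_attn as a flag rather than hard-coding the decomposition
presumably leaves room for backends whose attention spec can match the fused
aten.scaled_dot_product_attention op directly; passing decomp_attn=False
restores that behavior.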