Add attention decomposition mechanism to sdxl clip exports.
eagarvey-amd committed Sep 11, 2024
1 parent 35517d9 commit 6ca109a
Showing 2 changed files with 74 additions and 50 deletions.
models/turbine_models/custom_models/sdxl_inference/clip.py (31 additions, 19 deletions)
@@ -62,6 +62,7 @@ def export_clip_model(
     input_mlir=None,
     attn_spec=None,
     weights_only=False,
+    decomp_attn=True,
 ):
     if pipeline_dir not in [None, ""]:
         safe_name = os.path.join(pipeline_dir, "clip_" + str(index))
@@ -118,25 +119,36 @@ def export_clip_model(
 
     if weights_only:
         return weights_path
-
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                text_encoder_model,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(text_encoder_model)
-
-        def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)):
-            return jittable(text_encoder_model.forward)(inp)
-
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledClip(context=Context(), import_to=import_to)
-
-    module_str = str(CompiledModule.get_mlir_module(inst))
+    decomp_list = []
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
+
+            def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)):
+                return jittable(text_encoder_model.forward)(inp)
+
+        import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
+        inst = CompiledClip(context=Context(), import_to=import_to)
+
+        module_str = str(CompiledModule.get_mlir_module(inst))
 
     if compile_to != "vmfb":
         return module_str, tokenizer
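
The mechanism added here is the same in both exporters: when decomp_attn is set, the scaled-dot-product-attention ops are registered as AOT decompositions for the duration of the CompiledModule definition, so the imported MLIR carries their decomposed form instead of the fused attention ops. A minimal sketch of that pattern in isolation; the import path for the decompositions helper is an assumption, since the diff does not show where it comes from:

# Sketch only: mirrors the pattern this commit adds around the CompiledModule
# definitions. The import path below is an assumption; the diff does not show
# where `decompositions` is imported from.
import torch
from iree.turbine.aot import decompositions  # assumed import path

def attention_decomp_ops(decomp_attn: bool):
    # The three SDPA variants listed in the commit; decomposing them avoids
    # emitting fused flash-attention ops in the exported program.
    if not decomp_attn:
        return []
    return [
        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
        torch.ops.aten._scaled_dot_product_flash_attention.default,
        torch.ops.aten.scaled_dot_product_attention,
    ]

# The exporters wrap module definition and MLIR import in this context so the
# extended decomposition table is active while the graph is traced:
#
#     with decompositions.extend_aot_decompositions(
#         from_current=True, add_ops=attention_decomp_ops(decomp_attn)
#     ):
#         ...  # define CompiledClip and call CompiledModule.get_mlir_module
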
Second changed file (path not shown): 43 additions & 31 deletions
@@ -171,7 +171,7 @@ def export_prompt_encoder(
     attn_spec=None,
     weights_only=False,
     batch_input=False,
-    decomp_attn=False,  # Compatibility
+    decomp_attn=True,
 ):
     do_classifier_free_guidance = True
 
@@ -233,39 +233,51 @@ def export_prompt_encoder(
     if weights_only:
         return None, external_weight_path
 
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                prompt_encoder_module,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(prompt_encoder_module)
-
-        def encode_prompts(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward)(
-                t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
-            )
+    decomp_list = []
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    prompt_encoder_module,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(prompt_encoder_module)
+
+            def encode_prompts(
+                self,
+                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            ):
+                return jittable(prompt_encoder_module.forward)(
+                    t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
+                )
 
-        def encode_prompts_turbo(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
+            def encode_prompts_turbo(
+                self,
+                t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+                t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
+            ):
+                return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
 
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledClip(context=Context(), import_to=import_to)
+        import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
+        inst = CompiledClip(context=Context(), import_to=import_to)
 
-    module = CompiledModule.get_mlir_module(inst)
+        module = CompiledModule.get_mlir_module(inst)
 
     model_metadata_encode = {
         "model_name": hf_model_name + "_text_encoder",
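
Note that decomp_attn now defaults to True in both export_clip_model and export_prompt_encoder, so callers get decomposed attention unless they opt out. A hypothetical call site follows; only attn_spec, weights_only, batch_input, and decomp_attn appear in this diff, and the exporter's remaining parameters are stood in for by a placeholder dict:

# Hypothetical call site, not taken from the commit. `other_args` stands in for
# the exporter's remaining parameters (model name, precision, device, ...),
# which this diff does not show. Assumes export_prompt_encoder is imported from
# the module changed in this commit.
other_args = {}  # fill in the real keyword arguments of export_prompt_encoder

result = export_prompt_encoder(
    **other_args,
    attn_spec=None,
    weights_only=False,
    batch_input=False,
    decomp_attn=True,  # new default: decompose SDPA ops before MLIR import
)
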
