Skip to content

Commit

Permalink
Small fixes to SDXL inference pipeline/exports/compile
Browse files — browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Sep 24, 2024
1 parent dbc7635 commit a4a6801
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 7 deletions.
2 changes: 2 additions & 0 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,8 @@ def __init__(
target, dict
), "Device and target triple must be both dicts or both strings."
for submodel in self.map.keys():
if self.map[submodel].get("load") == False:
continue
assert submodel in device.keys(), f"Device for {submodel} not found."
assert (
submodel in target.keys()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@
"decomp_attn": None,
},
},
}
sdxl_compiled_pipeline_map = {
"unetloop": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
Expand Down Expand Up @@ -434,7 +436,7 @@ def load_scheduler(
if self.is_sd3:
export_fn = sd3_schedulers.export_scheduler_model
else:
export_fn = scheduler.export_scheduler_model
export_fn = schedulers.export_scheduler_model
self.map["scheduler"] = {
"module_name": "compiled_scheduler",
"export_fn": export_fn,
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=
url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir"
elif not masked_attention:
suffix = ""
url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir"
url = "https://raw.githubusercontent.com/iree-org/iree/refs/heads/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
else:
suffix = "_pad"
url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
Expand Down
1 change: 0 additions & 1 deletion models/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def export_vae_model(
vae_model,
external_weights,
external_weight_path,
vae_harness=vae_harness,
)
if weights_only:
return external_weight_path
Expand Down
4 changes: 0 additions & 4 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,6 @@ def export_unet_model(
if not attn_spec:
if (not decomp_attn) and use_punet:
attn_spec = "punet"
elif (not decomp_attn) and "gfx9" in target:
attn_spec = "mfma"
elif (not decomp_attn) and "gfx11" in target:
attn_spec = "wmma"
safe_name = utils.create_safe_name(
hf_model_name,
f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}",
Expand Down

0 comments on commit a4a6801

Please sign in to comment.