Skip to content

Commit

Permalink
Integration of tk kernels into pipeline (#789)
Browse files Browse the repository at this point in the history
Currently the tk kernels are referenced by a link, but Nithin will be pushing a fix
to use a file name instead as soon as possible.
  • Loading branch information
saienduri authored Jul 24, 2024
1 parent 25b2462 commit 0e57b4e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def __init__(
punet_quant_paths: dict[str] = None,
vae_weight_path: str = None,
vae_harness: bool = False,
add_tk_kernels: bool = False,
):
common_export_args = {
"hf_model_name": None,
Expand Down Expand Up @@ -316,6 +317,7 @@ def __init__(
self.scheduler = None

self.split_scheduler = True
self.add_tk_kernels = add_tk_kernels

self.base_model_name = (
hf_model_name
Expand Down Expand Up @@ -367,6 +369,8 @@ def __init__(

def setup_punet(self):
if self.use_i8_punet:
if self.add_tk_kernels:
self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
self.map["unet"]["export_args"]["precision"] = "i8"
self.map["unet"]["export_args"]["external_weight_path"] = (
utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
Expand Down
20 changes: 13 additions & 7 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,15 @@ def iree_backend_map(device):
return iree_device


def replace_with_tk_kernels(
flow_dialect_ir,
):
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
]
def replace_with_tk_kernels(flow_dialect_ir, batch_size):
if batch_size == 8:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir"
]
if batch_size == 1:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
]

# Replace all calls to old kernel with new kernel
print("Inserting kernels and updating calls to kernels...")
Expand Down Expand Up @@ -235,7 +238,10 @@ def compile_to_vmfb(
flagset_keywords=[],
debug=False,
add_tk_kernels=False,
batch_size=1,
):
if batch_size != 1 and batch_size != 8:
add_tk_kernels = False
flags = []
if mlir_source == "file" and not isinstance(module_str, str):
module_str = str(module_str)
Expand Down Expand Up @@ -393,7 +399,7 @@ def compile_to_vmfb(

flow_ir = flatbuffer_blob.decode("utf-8")

flow_ir_tk = replace_with_tk_kernels(flow_ir)
flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size)
module_str = "\n".join(flow_ir_tk)
flags.pop()
flags.extend(["--compile-from=flow"])
Expand Down
1 change: 1 addition & 0 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ class CompiledUnet(CompiledModule):
attn_spec=attn_spec,
flagset_keywords=["punet"] if use_punet else [],
add_tk_kernels=add_tk_kernels,
batch_size=batch_size,
)
if exit_on_vmfb:
exit()
Expand Down

0 comments on commit 0e57b4e

Please sign in to comment.