Skip to content

Commit

Permalink
Integrate int8 tk kernels (#783)
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah authored Jul 23, 2024
1 parent d857f77 commit 37548f2
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 2 deletions.
97 changes: 95 additions & 2 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,70 @@ def iree_backend_map(device):
return iree_device


def replace_with_tk_kernels(
    flow_dialect_ir,
):
    """Replace matching matmul dispatches in flow-dialect IR with tk kernels.

    Rewrites references to compiler-generated ``matmul_like_BxMxNxK``
    dispatch functions so they point at hand-written tk kernels fetched
    from the nod-ai/sdxl-scripts repository, and splices each fetched
    kernel's IR in front of the ``flow.executable`` that previously held
    the replaced dispatch.

    Args:
        flow_dialect_ir: Flow-dialect MLIR module as a single string
            (output of IREE compilation with ``--compile-to=flow``).

    Returns:
        The rewritten IR as a list of lines/fragments; callers join it
        back into a module string with ``"\\n".join(...)``.

    NOTE(review): assumes every kernel in ``kernels`` has a matching
    ``func.func`` in the IR, and that the ``func.func`` definition appears
    before any ``flow.dispatch`` to it — otherwise the lookups below raise
    ``KeyError``. Confirm against real flow IR.
    """
    kernels = [
        "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
    ]

    # Replace all calls to old kernel with new kernel
    print("Inserting kernels and updating calls to kernels...")

    # Per-kernel metadata derived from the kernel file name, computed once
    # up front instead of per IR line.
    kernel_name = {}
    kernel_info = {}
    for kernel in kernels:
        fname = kernel.split("/")[-1].split(".")[0]
        kernel_name[kernel] = fname
        suffix = fname.split("_")[-1]
        bias_explicit = False
        kernel_args = 0
        if "bias" in suffix:
            # Kernels with an explicit bias encode the extra operand count
            # after the "bias" tag (e.g. ..._BxMxNxK_bias1.mlir); the shape
            # then lives in the second-to-last "_" segment of the file name.
            bias_explicit = True
            kernel_args = 3 + int(suffix[4:])
            # BUG FIX: was kernel.split(".")[0].split("_")[-2], which splits
            # the full URL on "." and yields "https://raw" (no "_" segments,
            # so indexing [-2] raises IndexError). Parse the file name.
            suffix = fname.split("_")[-2]
        B, M, N, K = suffix.split("x")
        kernel_info[kernel] = (
            f"matmul_like_{B}x{M}x{N}x{K}",
            bias_explicit,
            kernel_args,
        )

    kernel_map = {}
    prefix_map = {}

    base = flow_dialect_ir.split("\n")
    new_base = []
    for line in base:
        for kernel in kernels:
            old_kernel, bias_explicit, kernel_args = kernel_info[kernel]
            if old_kernel not in line:
                continue
            if "func.func" in line:
                if bias_explicit:
                    # Only adopt this kernel for dispatches whose argument
                    # count matches; other arities keep the original code.
                    num_args = line.count("arg")
                    if num_args != kernel_args:
                        continue
                # Record the full dispatch symbol (strip the leading "@" and
                # a 7-char type suffix) and its executable-name prefix.
                # NOTE(review): the fixed [1:-7] slice assumes a specific
                # symbol layout — confirm against the emitted flow IR.
                kernel_map[kernel] = line.strip().split(" ")[1][1:-7]
                prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1]
            if "flow.dispatch" in line and "func.func" not in line:
                # Re-point the dispatch at the tk kernel.
                line = line.replace(kernel_map[kernel], kernel_name[kernel])
                line = line.replace(prefix_map[kernel], kernel_name[kernel])
        new_base.append(line)

    # Insert kernels in appropriate locations: the fetched kernel IR goes
    # immediately before the flow.executable that held the old dispatch.
    final_ir = []
    for line in new_base:
        for kernel in kernels:
            if (
                prefix_map[kernel] + " {" in line
                and "flow.executable" in line
                and "private" in line
            ):
                data = urlopen(kernel).read().decode("utf-8")
                data = data.split("\n")
                # Inline the kernel's #translation attribute at its use site
                # and drop the module wrapper lines. NOTE(review): assumes a
                # fixed kernel-file layout (attribute on line 0, use on line
                # 10, 3 trailing wrapper lines) — confirm against tk_kernels.
                translation_info = data[0].split("#translation = ")[1].strip()
                data[10] = data[10].replace("#translation", translation_info)
                final_ir.append("\n".join(data[2:-3]))
        final_ir.append(line)

    print("tk kernels added")
    return final_ir


def compile_to_vmfb(
module_str,
device,
Expand All @@ -170,6 +234,7 @@ def compile_to_vmfb(
winograd=False,
flagset_keywords=[],
debug=False,
add_tk_kernels=False,
):
flags = []
if mlir_source == "file" and not isinstance(module_str, str):
Expand Down Expand Up @@ -307,6 +372,34 @@ def compile_to_vmfb(
for idx, flag in enumerate(flags):
if flag is None:
flags.pop(idx)
input_ir_type = "torch"
if add_tk_kernels:
print("Adding tk kernels")
flags.extend(["--compile-to=flow"])
if mlir_source == "file":
flatbuffer_blob = ireec.compile_file(
module_str,
target_backends=[device],
input_type=input_ir_type,
extra_args=flags,
)
elif mlir_source == "str":
flatbuffer_blob = ireec.compile_str(
module_str,
target_backends=[device],
input_type=input_ir_type,
extra_args=flags,
)

flow_ir = flatbuffer_blob.decode("utf-8")

flow_ir_tk = replace_with_tk_kernels(flow_ir)
module_str = "\n".join(flow_ir_tk)
flags.pop()
flags.extend(["--compile-from=flow"])
mlir_source = "str"
input_ir_type = "auto"

print("Compiling to", device, "with flags:", flags)

# Forces a standard for naming files:
Expand All @@ -323,7 +416,7 @@ def compile_to_vmfb(
flatbuffer_blob = ireec.compile_file(
module_str,
target_backends=[device],
input_type="torch",
input_type=input_ir_type,
extra_args=flags,
)
elif mlir_source == "str":
Expand All @@ -334,7 +427,7 @@ def compile_to_vmfb(
flatbuffer_blob = ireec.compile_str(
module_str,
target_backends=[device],
input_type="torch",
input_type=input_ir_type,
extra_args=flags,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,5 +369,11 @@ def is_valid_file(arg):
help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py",
)

# Opt-in flag: when set, compilation can route through flow dialect so
# hand-written tk kernels are spliced in (consumed by compile_to_vmfb in
# sd_inference/utils.py). store_true, so it defaults to off.
p.add_argument(
    "--add_tk_kernels",
    default=False,
    action="store_true",
    help="Flag to add compiled tk kernels.",
)

args, unknown = p.parse_known_args()
8 changes: 8 additions & 0 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def export_unet_model(
weights_only=False,
use_punet=False,
quant_paths=None,
add_tk_kernels=False,
):
if use_punet:
submodel_name = "punet"
Expand All @@ -217,6 +218,10 @@ def export_unet_model(
if decomp_attn == True:
ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False"

# Currently, only int8 tk kernels are integrated
if add_tk_kernels and precision != "i8":
add_tk_kernels = False

if input_mlir:
vmfb_path = utils.compile_to_vmfb(
input_mlir,
Expand All @@ -228,6 +233,7 @@ def export_unet_model(
return_path=not exit_on_vmfb,
attn_spec=attn_spec,
flagset_keywords=["punet"] if use_punet else [],
add_tk_kernels=add_tk_kernels,
)
return vmfb_path
elif use_punet:
Expand Down Expand Up @@ -363,6 +369,7 @@ class CompiledUnet(CompiledModule):
return_path=True,
attn_spec=attn_spec,
flagset_keywords=["punet"] if use_punet else [],
add_tk_kernels=add_tk_kernels,
)
if exit_on_vmfb:
exit()
Expand Down Expand Up @@ -401,6 +408,7 @@ class CompiledUnet(CompiledModule):
args.decomp_attn,
attn_spec=args.attn_spec,
input_mlir=args.input_mlir,
add_tk_kernels=args.add_tk_kernels,
)
if args.input_mlir:
exit()
Expand Down

0 comments on commit 37548f2

Please sign in to comment.