Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Jul 25, 2024
1 parent 5876436 commit 51ef1db
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
if batch_size == 1:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir",
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir"
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir",
]

# Replace all calls to old kernel with new kernel
Expand Down Expand Up @@ -192,8 +192,10 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
if old_kernel in line and "func.func" in line:
data = urlopen(kernel).read().decode("utf-8")
data = data.split("\n")
idx_with_kernel_args = [idx for idx, s in enumerate(data) if 'func.func' in s][0]
kernel_args = data[idx_with_kernel_args].count('arg')
idx_with_kernel_args = [
idx for idx, s in enumerate(data) if "func.func" in s
][0]
kernel_args = data[idx_with_kernel_args].count("arg")
num_args = line.count("arg")
if num_args != kernel_args:
continue
Expand Down Expand Up @@ -552,11 +554,11 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
# schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained(
# model_id,
Expand Down

0 comments on commit 51ef1db

Please sign in to comment.