
Fix the compiled pipeline compilation issue
gpetters-amd committed Jul 29, 2024
1 parent c1e9195 commit eddee10
Showing 6 changed files with 23 additions and 27 deletions.
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/pipeline_base.py
@@ -671,7 +671,7 @@ def export_submodel(
self.map[submodel]["export_args"]["batch_size"],
self.map[submodel]["export_args"]["max_length"],
"produce_img_split",
-unet_module_name = self.map["unet"]["module_name"],
+unet_module_name=self.map["unet"]["module_name"],
)
dims = [
self.map[submodel]["export_args"]["width"],
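The change above is a pure style fix: Black formats keyword arguments in a call without spaces around `=`. Functionally, this keyword is what ties the generated pipeline IR to the exported UNet, since the module name recorded in the submodel map must match the callee symbol in the MLIR template from pipeline_ir.py (shown further down). A minimal sketch of that wiring in the call here (presumably to get_pipeline_ir), with the positional arguments the diff elides replaced by `...` and illustrative values in comments:

    # Sketch only: the full signature of get_pipeline_ir and the contents of
    # self.map are assumed from the lines shown above, not from the whole file.
    unet_module_name = self.map["unet"]["module_name"]  # e.g. "compiled_unet" (illustrative)
    pipeline_ir = get_pipeline_ir(
        ...,  # width/height/precision arguments elided in the diff
        self.map[submodel]["export_args"]["batch_size"],
        self.map[submodel]["export_args"]["max_length"],
        "produce_img_split",
        unet_module_name=unet_module_name,
    )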
@@ -727,10 +727,7 @@ def generate_images(
for i in range(batch_count):
if self.compiled_pipeline:
image = produce_images_compiled(
-samples[i],
-prompt_embeds,
-negative_embeds,
-guidance_scale
+samples[i], prompt_embeds, negative_embeds, guidance_scale
)
else:
produce_latents_input = [
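Another Black reformat, collapsing the argument list onto one line. For context, this branch is the dispatch point between the two execution modes; restated as a condensed sketch (taken from the lines above, with the non-compiled path elided):

    for i in range(batch_count):
        if self.compiled_pipeline:
            # One call apparently runs scheduler setup, the denoising loop,
            # and VAE decode inside a single compiled module (see the
            # produce_image_latents IR in pipeline_ir.py below).
            image = produce_images_compiled(
                samples[i], prompt_embeds, negative_embeds, guidance_scale
            )
        else:
            # Otherwise each submodel is driven step by step from Python.
            produce_latents_input = [...]  # contents elided in the diff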
6 changes: 3 additions & 3 deletions models/turbine_models/custom_models/sd_inference/utils.py
@@ -36,13 +36,13 @@
],
"unet": [
"--iree-flow-enable-aggressive-fusion",
"--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-global-opt-enable-fuse-horizontal-contractions=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
],
"clip": [
"--iree-flow-enable-aggressive-fusion",
"--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-global-opt-enable-fuse-horizontal-contractions=true",
"--iree-opt-aggressively-propagate-transposes=true",
],
"vae": [
@@ -61,7 +61,7 @@
"--iree-opt-const-eval=false",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-flow-enable-aggressive-fusion",
"--iree-flow-enable-fuse-horizontal-contractions=true",
"--iree-global-opt-enable-fuse-horizontal-contractions=true",
"--iree-codegen-gpu-native-math-precision=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
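This is the substantive fix behind the commit title: judging by the rename, the fuse-horizontal-contractions option moved out of IREE's Flow pipeline and into its global optimization pipeline, so the old `--iree-flow-...` spelling is no longer recognized, and an unrecognized flag fails the entire iree-compile invocation. The same one-line substitution is applied to the unet, clip, and vae flag sets. A standalone sketch of how a per-submodel flag list like the ones above might be fed to the compiler (the wrapper and target flags are illustrative; the repo itself drives compilation through utils.compile_to_vmfb):

    import subprocess

    UNET_FLAGS = [
        "--iree-flow-enable-aggressive-fusion",
        "--iree-global-opt-enable-fuse-horizontal-contractions=true",
        "--iree-opt-aggressively-propagate-transposes=true",
        "--iree-codegen-llvmgpu-use-vector-distribution=true",
    ]

    def compile_unet(mlir_path: str, chip: str = "gfx942") -> str:
        """Compile UNet MLIR to a vmfb via the CLI (illustrative; target
        flag spellings vary across IREE versions)."""
        vmfb_path = mlir_path.replace(".mlir", ".vmfb")
        subprocess.run(
            [
                "iree-compile",
                mlir_path,
                "--iree-hal-target-backends=rocm",
                f"--iree-rocm-target-chip={chip}",
                *UNET_FLAGS,
                "-o",
                vmfb_path,
            ],
            check=True,
        )
        return vmfb_path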
31 changes: 15 additions & 16 deletions models/turbine_models/custom_models/sdxl_inference/pipeline_ir.py
@@ -48,28 +48,27 @@

produce_img_split = r"""
module @sdxl_compiled_pipeline {{
-func.func private @{scheduler_module}.run_initialize(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],f16>, !torch.vtensor<[{num_steps}],f32>) attributes {{torch.assume_strict_symbolic_shapes}}
-func.func private @{scheduler_module}.run_scale(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
-func.func private @{scheduler_module}.run_step(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],{precision}>, %arg2: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
-func.func private @{unet_module}.{unet_function}(%arg0: !torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, %arg1: !torch.vtensor<[1],{precision}>, %arg2: !torch.vtensor<[{bd},{max_length},2048],{precision}>, %arg3: !torch.vtensor<[{bd},1280],{precision}>, %arg4: !torch.vtensor<[{bd},6],{precision}>, %arg5: !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
-func.func private @{vae_module}.decode(%arg0: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
+func.func private @{scheduler_module}.run_initialize(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1xf16>, tensor<{num_steps}xf32>) attributes {{torch.assume_strict_symbolic_shapes}}
+func.func private @{scheduler_module}.run_scale(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1xi64>, %arg2: tensor<{num_steps}xf32>) -> (tensor<{bd}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>) attributes {{torch.assume_strict_symbolic_shapes}}
+func.func private @{scheduler_module}.run_step(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
+func.func private @{unet_module}.{unet_function}(%arg0: tensor<{bd}x4x{lh}x{lw}x{precision}>, %arg1: tensor<1x{precision}>, %arg2: tensor<{bd}x{max_length}x2048x{precision}>, %arg3: tensor<{bd}x1280x{precision}>, %arg4: tensor<{bd}x6x{precision}>, %arg5: tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
+func.func private @{vae_module}.decode(%arg0: tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> attributes {{torch.assume_strict_symbolic_shapes}}
-func.func @produce_image_latents(%sample: !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, %p_embeds: !torch.vtensor<[{bd},{max_length},2048],{precision}>, %t_embeds: !torch.vtensor<[{bd},1280],{precision}>, %guidance_scale: !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}> {{
-%noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{num_steps}],f32>)
+func.func @produce_image_latents(%sample: tensor<{batch_size}x4x{lh}x{lw}x{precision}>, %p_embeds: tensor<{bd}x{max_length}x2048x{precision}>, %t_embeds: tensor<{bd}x1280x{precision}>, %guidance_scale: tensor<1x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}> {{
+%noisy_sample, %time_ids, %delete, %timesteps = func.call @{scheduler_module}.run_initialize(%sample) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>, tensor<{num_steps}xf32>)
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%n_steps = arith.constant {num_steps} : index
-%res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) {{
+%res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) {{
%step_64 = arith.index_cast %arg0 : index to i64
%this_step = tensor.from_elements %step_64 : tensor<1xi64>
-%step_torch = torch_c.from_builtin_tensor %this_step : tensor<1xi64> -> !torch.vtensor<[1],si64>
-%scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %step_torch, %timesteps) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],si64>, !torch.vtensor<[{num_steps}],f32>) -> (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>)
-%inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (!torch.vtensor<[{bd},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{bd},{max_length},2048],{precision}>, !torch.vtensor<[{bd},1280],{precision}>, !torch.vtensor<[{bd},6],{precision}>, !torch.vtensor<[1],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
-%pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>, !torch.vtensor<[1],{precision}>, !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
-scf.yield %pred : !torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>
+%scaled, %timestep = func.call @{scheduler_module}.run_scale(%arg, %this_step, %timesteps) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1xi64>, tensor<{num_steps}xf32>) -> (tensor<{bd}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>)
+%inner = func.call @{unet_module}.{unet_function}(%scaled, %timestep, %p_embeds, %t_embeds, %time_ids, %guidance_scale) : (tensor<{bd}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{bd}x{max_length}x2048x{precision}>, tensor<{bd}x1280x{precision}>, tensor<{bd}x6x{precision}>, tensor<1x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
+%pred = func.call @{scheduler_module}.run_step(%inner, %timestep, %arg) : (tensor<{batch_size}x4x{lh}x{lw}x{precision}>, tensor<1x{precision}>, tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x4x{lh}x{lw}x{precision}>
+scf.yield %pred : tensor<{batch_size}x4x{lh}x{lw}x{precision}>
}}
-%image = func.call @{vae_module}.decode(%res): (!torch.vtensor<[{batch_size},4,{lh},{lw}],{precision}>) -> !torch.vtensor<[{batch_size},3,{height},{width}],{precision}>
-return %image : !torch.vtensor<[{batch_size},3,{height},{width}],{precision}>
+%image = func.call @{vae_module}.decode(%res): (tensor<{batch_size}x4x{lh}x{lw}x{precision}>) -> tensor<{batch_size}x3x{height}x{width}x{precision}>
+return %image : tensor<{batch_size}x3x{height}x{width}x{precision}>
}}
}}
"""
@@ -128,4 +127,4 @@ def get_pipeline_ir(
scheduler_module=scheduler_module_name,
vae_module=vae_module_name,
num_steps=num_steps,
-)
+)
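Two things change in this template. First, every `!torch.vtensor<[...],{precision}>` type is rewritten as a builtin `tensor<...x{precision}>`, in the private declarations, the loop-carried value, and all call signatures, evidently to match the types the exported submodules now expose (note that run_initialize's third declared result stays hard-coded at f16 while its call site uses {precision}, a quirk carried over from the old template). Second, with builtin tensors throughout, the torch_c.from_builtin_tensor bridge for the step index is gone: %this_step, already a tensor<1xi64>, is passed to run_scale directly. Since the template is a Python format string (doubled braces emit literal MLIR braces), instantiation looks roughly like this, with example values for fields the diff leaves unspecified:

    # Example values only; the real ones come from get_pipeline_ir's arguments,
    # of which the diff shows scheduler_module, vae_module, and num_steps.
    pipeline_mlir = produce_img_split.format(
        batch_size=1,
        bd=2,                      # doubled batch for the UNet's CFG inputs,
                                   # judging by the signatures above
        lh=128, lw=128,            # latent dims; image dims // 8 for SDXL
        height=1024, width=1024,
        max_length=64,
        num_steps=30,
        precision="f16",
        unet_module=unet_module_name,        # threaded in from pipeline_base.py
        unet_function="run_forward",         # illustrative name
        scheduler_module=scheduler_module_name,
        vae_module=vae_module_name,
    )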
@@ -480,6 +480,7 @@ def export_submodel(
self.hf_model_name,
None,
self.max_length,
+self.batch_size,
self.precision,
"vmfb",
self.external_weights,
@@ -494,7 +495,6 @@
input_mlir=input_mlir["prompt_encoder"],
attn_spec=self.attn_spec,
weights_only=weights_only,
-batchsize=self.batch_size,
batch_input=self.batch_prompt_input,
)
return prompt_encoder_vmfb, prompt_encoder_external_weight_path
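The prompt-encoder export previously received the batch size as a trailing batchsize= keyword; it is now passed positionally, right after max_length, evidently to match the exporter's actual signature (a keyword name the callee does not accept raises a TypeError at call time). A sketch of the corrected call shape; the exporter's name and the elided middle arguments are assumptions, not shown in the diff:

    prompt_encoder_vmfb, prompt_encoder_external_weight_path = export_prompt_encoder(
        self.hf_model_name,
        None,
        self.max_length,
        self.batch_size,     # now positional, immediately after max_length
        self.precision,
        "vmfb",
        self.external_weights,
        ...,                 # remaining positional arguments unchanged
        input_mlir=input_mlir["prompt_encoder"],
        attn_spec=self.attn_spec,
        weights_only=weights_only,
        batch_input=self.batch_prompt_input,
    )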
@@ -277,7 +277,7 @@ def encode_prompts_turbo(
module_str = str(module)

if compile_to != "vmfb":
-return module_str
+return module_str, None
else:
vmfb_path = utils.compile_to_vmfb(
module_str,
@@ -289,7 +289,7 @@
const_expr_hoisting=True,
attn_spec=attn_spec,
)
-return vmfb_path
+return None, vmfb_path


if __name__ == "__main__":
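encode_prompts_turbo used to return a bare MLIR string on one path and a bare vmfb path on the other; both branches now return a uniform two-slot tuple with the unused slot set to None, presumably so callers can unpack the result unconditionally. A usage sketch (the downstream helpers are hypothetical):

    module_str, vmfb_path = encode_prompts_turbo(...)  # arguments elided
    if vmfb_path is not None:
        run_vmfb(vmfb_path)        # compiled path, compile_to == "vmfb"
    else:
        inspect_mlir(module_str)   # MLIR-only path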
