Error code: 1
Diagnostics:
iree-compile: Too many positional arguments specified!
Can specify at most 1 positional arguments: See: /opt/conda/envs/turb/lib/python3.11/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile --help
Invoked with:
iree-compile /opt/conda/envs/turb/lib/python3.11/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false iree-hal-target-backends=rocm --iree-hip-target=gfx90a --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-dispatch-creation-enable-aggressive-fusion --iree-dispatch-creation-enable-fuse-horizontal-contractions --iree-opt-aggressively-propagate-transposes=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-vm-target-truncate-unsupported-floats --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))
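The invocation log above shows the likely cause of the failure: the first entry of the extra flags, iree-hal-target-backends=rocm, is missing its leading "--", so iree-compile parses it as a second positional input (alongside "-", the stdin placeholder) and aborts with "Too many positional arguments". Since compile_str(..., target_backends=["rocm"]) already adds --iree-hal-target-backends=rocm (visible near the start of the same invocation), a minimal sketch of the fix is simply to drop that entry from the extra-args list used in the script below:

    # Sketch of a corrected head of the rocm_flags list from the script below.
    # The backend is already selected via target_backends=["rocm"], so the
    # malformed "iree-hal-target-backends=rocm" entry can be removed outright.
    rocm_flags = [
        "--iree-hip-target=gfx90a",
        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
        # ... remaining flags unchanged ...
    ]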
import torch
import iree.compiler
import iree.runtime
from diffusers import UNet2DConditionModel
from shark_turbine.aot import *


class UnetModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet",
            low_cpu_mem_usage=False,
        )

    def forward(self, latent_model_input, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale):
        added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
        noise_pred = self.unet.forward(
            latent_model_input,
            timestep,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        return noise_pred


dtype = torch.float16
model = UnetModel().to(dtype=dtype)

# Example inputs that pin the shapes and dtypes of the exported program.
example_forward_args = [
    torch.empty((1, 4, 128, 128), dtype=dtype),
    torch.empty(1, dtype=torch.long),
    torch.empty((1, 64, 2048), dtype=dtype),
    torch.empty((1, 1280), dtype=dtype),
    torch.empty((1, 6), dtype=dtype),
    torch.tensor([7.5], dtype=dtype),
]

fxb = FxProgramsBuilder(model)


@fxb.export_program(args=(example_forward_args,))
def _forward(model, inputs):
    return model.forward(*inputs)


class CompiledUnet(CompiledModule):
    run_forward = _forward


inst = CompiledUnet(context=iree.compiler.ir.Context(), import_to="IMPORT")
module = CompiledModule.get_mlir_module(inst)

rocm_flags = [
    "iree-hal-target-backends=rocm",  # note: missing the leading "--" (see diagnostics above)
    "--iree-hip-target=gfx90a",
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    "--iree-dispatch-creation-enable-aggressive-fusion",
    "--iree-dispatch-creation-enable-fuse-horizontal-contractions",
    "--iree-opt-aggressively-propagate-transposes=true",
    "--iree-codegen-llvmgpu-use-vector-distribution=true",
    "--iree-opt-data-tiling=false",
    "--iree-codegen-gpu-native-math-precision=true",
    "--iree-vm-target-truncate-unsupported-floats",
    "--iree-global-opt-propagate-transposes=true",
    "--iree-opt-const-eval=false",
    "--iree-llvmgpu-enable-prefetch=true",
    "--iree-execution-model=async-external",
    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))",
]

compiled_bin = iree.compiler.compile_str(
    str(module), target_backends=["rocm"], input_type="torch", extra_args=rocm_flags
)
with open("unet.vmfb", "wb") as f:
    f.write(compiled_bin)

config = iree.runtime.Config("hip")
rt_module = iree.runtime.VmModule.mmap(config.vm_instance, "unet.vmfb")
hal_module = iree.runtime.create_hal_module(config.vm_instance, config.device)
vm_modules = [rt_module, hal_module]
ctx = iree.runtime.SystemContext(vm_modules=vm_modules, config=config)
# Smoke-test invocation, reusing the example tensors as inputs.
unet_output = ctx.modules.compiled_unet["run_forward"](*example_forward_args).to_host()
print(unet_output)
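If the malformed flag is removed as noted above and the compile step succeeds, the runtime call at the end of the script should receive one argument per entry in example_forward_args. The IREE Python bindings also accept plain numpy host arrays as call arguments; a sketch of a more explicit smoke-test invocation, reusing the ctx built above with zero-filled placeholder inputs (names here are illustrative, not from the original report):

    import numpy as np

    # Placeholder host inputs mirroring the shapes/dtypes of example_forward_args.
    np_inputs = [
        np.zeros((1, 4, 128, 128), dtype=np.float16),
        np.zeros((1,), dtype=np.int64),
        np.zeros((1, 64, 2048), dtype=np.float16),
        np.zeros((1, 1280), dtype=np.float16),
        np.zeros((1, 6), dtype=np.float16),
        np.asarray([7.5], dtype=np.float16),
    ]
    noise_pred = ctx.modules.compiled_unet["run_forward"](*np_inputs)
    print(noise_pred.to_host())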