-
Notifications
You must be signed in to change notification settings - Fork 48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upstreaming MLPerf punet changes, server/harness support. #799
base: main
Are you sure you want to change the base?
Changes from 53 commits
a3376d9
7cabac0
7dfd4c8
1cd3ee9
1a90abd
b70318d
2d7ebcd
feebc87
af7782b
eff59a9
63fb053
aff48ab
b10ad8d
321d21d
2de912e
9fdc07f
b20be32
02705a9
9229aed
f09ef4a
bbcc424
1f19c7f
39c0c00
3c59b25
2e9de46
aa0ac2b
6556a36
2c49cb6
cb911b1
d857f77
37548f2
25b2462
0e57b4e
920dbf5
6f16731
15dbd93
0c02652
3fd954b
7f8a2b0
bf63aec
a74d98e
925cd0c
7715fd0
e276c78
de5d3de
d0d3ae6
403fe47
0ac6b64
1a41394
493f260
711403c
e554da8
cdd2f66
d23a45b
7ecfece
2d7a92e
18bffdb
df85dca
674128e
4d6198b
f3e3fe3
2ed8037
7adfc7a
ff2c3c9
afdb8d6
a4e67e8
35517d9
6ca109a
453fb38
c0be575
e3cd69d
e3e1dcb
56d6ee7
d330564
ffba3ea
fad7e6e
05fa32d
0291d43
e337f2a
0fd8ad0
e1c4ac2
61bb4ef
dfb9474
9fe20a6
f39b2d2
d3c8e80
40808db
e630d39
fc6d018
7d50dc8
f140926
67e6558
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from diffusers import StableDiffusion3Pipeline | ||
[Review comment] Should we integrate this with a test that outputs image numerics we can compare against? I know you saw some significantly different numerics between CPU and the various GPU backends, where a direct comparison may be difficult — maybe some FID/CLIP scores instead?
[Reply] Doing a faithful comparison with the diffusers reference is a larger problem — we are really best off investing in getting real CLIP/FID scores against a validation dataset. This diffusers reference is really just a hold-over/sanity check for now; I don't even trust it to give us a decent baseline from CPU. We can leave it out for now, but I'd rather keep it so we have something ready for comparison with diffusers on ROCm/CUDA.
||
import torch | ||
from datetime import datetime as dt | ||
|
||
|
||
def run_diffusers_cpu(
    hf_model_name,
    prompt,
    negative_prompt,
    guidance_scale,
    seed,
    height,
    width,
    num_inference_steps,
):
    """Generate a CPU reference image with the HF diffusers SD3 pipeline.

    Loads ``hf_model_name`` as a float32 StableDiffusion3Pipeline on CPU,
    runs one generation with a seeded ``torch.Generator`` (so results are
    reproducible for a given seed), and saves the first output image to a
    timestamped PNG in the current directory.

    Args:
        hf_model_name: Hugging Face model id or local path for the pipeline.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        guidance_scale: Classifier-free guidance scale.
        seed: RNG seed; coerced with int() so string CLI values work.
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps.

    Returns:
        The path of the saved PNG file (str). Previously the function
        returned None; returning the path is backward-compatible and lets
        callers locate the artifact.
    """
    # Imported lazily so this module can be imported without diffusers
    # installed; only calling this reference path requires it.
    from diffusers import StableDiffusion3Pipeline

    # float32 on CPU: this is a slow sanity-check reference, not a
    # performance path, so keep full precision for comparison value.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        hf_model_name, torch_dtype=torch.float32
    )
    pipe = pipe.to("cpu")
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=height,
        width=width,
        generator=generator,
    ).images[0]

    # Timestamped filename avoids clobbering earlier reference outputs.
    timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_path = f"diffusers_reference_output_{timestamp}.png"
    image.save(out_path)
    return out_path
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: pull shared SD command-line options and forward
    # them to the CPU diffusers reference runner.
    from turbine_models.custom_models.sd_inference.sd_cmd_opts import args

    run_diffusers_cpu(
        hf_model_name=args.hf_model_name,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        height=args.height,
        width=args.width,
        num_inference_steps=args.num_inference_steps,
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ def forward( | |
pooled_projections, | ||
timestep, | ||
): | ||
timestep.expand(hidden_states.shape[0]) | ||
noise_pred = self.mmdit( | ||
hidden_states, | ||
encoder_hidden_states, | ||
|
@@ -71,7 +72,7 @@ def forward(self, q, k, v): | |
def export_attn( | ||
precision="fp16", | ||
device="cpu", | ||
target_triple="x86_64-unknown-linux-gnu", | ||
target="x86_64-unknown-linux-gnu", | ||
ireec_flags="", | ||
compile_to="torch", | ||
decomp_attn=False, | ||
|
@@ -128,7 +129,7 @@ class CompiledAttn(CompiledModule): | |
vmfb_path = utils.compile_to_vmfb( | ||
module_str, | ||
device, | ||
target_triple, | ||
target, | ||
ireec_flags, | ||
safe_name, | ||
return_path=True, | ||
|
@@ -139,7 +140,6 @@ class CompiledAttn(CompiledModule): | |
|
||
@torch.no_grad() | ||
def export_mmdit_model( | ||
mmdit_model, | ||
hf_model_name, | ||
batch_size, | ||
height, | ||
|
@@ -151,8 +151,8 @@ def export_mmdit_model( | |
external_weights=None, | ||
external_weight_path=None, | ||
device=None, | ||
target_triple=None, | ||
ireec_flags=None, | ||
target=None, | ||
ireec_flags="", | ||
decomp_attn=False, | ||
exit_on_vmfb=False, | ||
pipeline_dir=None, | ||
|
@@ -161,6 +161,9 @@ def export_mmdit_model( | |
weights_only=False, | ||
): | ||
dtype = torch.float16 if precision == "fp16" else torch.float32 | ||
mmdit_model = MMDiTModel( | ||
dtype=dtype, | ||
) | ||
np_dtype = "float16" if precision == "fp16" else "float32" | ||
safe_name = utils.create_safe_name( | ||
hf_model_name, | ||
|
@@ -169,13 +172,14 @@ def export_mmdit_model( | |
if pipeline_dir: | ||
safe_name = os.path.join(pipeline_dir, safe_name) | ||
if decomp_attn == True: | ||
safe_name += "_decomp_attn" | ||
ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False" | ||
|
||
if input_mlir: | ||
vmfb_path = utils.compile_to_vmfb( | ||
input_mlir, | ||
device, | ||
target_triple, | ||
target, | ||
ireec_flags, | ||
safe_name, | ||
mlir_source="file", | ||
|
@@ -208,7 +212,7 @@ def export_mmdit_model( | |
torch.empty(hidden_states_shape, dtype=dtype), | ||
torch.empty(encoder_hidden_states_shape, dtype=dtype), | ||
torch.empty(pooled_projections_shape, dtype=dtype), | ||
torch.empty(init_batch_dim, dtype=dtype), | ||
torch.empty(1, dtype=dtype), | ||
] | ||
|
||
decomp_list = [] | ||
|
@@ -249,7 +253,7 @@ class CompiledMmdit(CompiledModule): | |
hidden_states_shape, | ||
encoder_hidden_states_shape, | ||
pooled_projections_shape, | ||
init_batch_dim, | ||
(1,), | ||
[Review comment] Is mmdit not working batched?
[Reply] No — this timestep input gets expanded inside the model to the batch dimension, so the exported signature only needs a single-element timestep.
||
], | ||
"input_dtypes": [np_dtype for x in range(4)], | ||
"output_shapes": [hidden_states_shape], | ||
|
@@ -263,7 +267,7 @@ class CompiledMmdit(CompiledModule): | |
vmfb_path = utils.compile_to_vmfb( | ||
module_str, | ||
device, | ||
target_triple, | ||
target, | ||
ireec_flags, | ||
safe_name, | ||
return_path=True, | ||
|
[Review comment] Commented-out code — consider removing it before merging.