Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstreaming MLPerf punet changes, server/harness support. #799

Open
wants to merge 92 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
a3376d9
Bump punet revision to d30d6ff
eagarvey-amd Jul 12, 2024
7cabac0
Enable punet t2i test.
eagarvey-amd Jul 12, 2024
7dfd4c8
Use formatted strings as input to printer.
eagarvey-amd Jul 12, 2024
1cd3ee9
Rework sdxl test to setup with a pipeline, fix unloading submodels, f…
eagarvey-amd Jul 12, 2024
1a90abd
Add switch for punet preprocessing flags
eagarvey-amd Jul 13, 2024
b70318d
Xfail punet e2e test.
eagarvey-amd Jul 13, 2024
2d7ebcd
Fixups to sdxl test arguments
eagarvey-amd Jul 15, 2024
feebc87
Fix flagset arg and enable vae encode.
eagarvey-amd Jul 15, 2024
af7782b
Enable VAE encode validation, mark as xfail
eagarvey-amd Jul 15, 2024
eff59a9
Fix formatting
eagarvey-amd Jul 15, 2024
63fb053
fix runner function name in old sd test.
eagarvey-amd Jul 15, 2024
aff48ab
Fix xfail syntax.
eagarvey-amd Jul 15, 2024
b10ad8d
Update unet script for compile function signature change
eagarvey-amd Jul 15, 2024
321d21d
Update punet to 4d4f955
IanNod Jul 16, 2024
2de912e
Disable vulkan test on MI250 runner.
monorimet Jul 16, 2024
9fdc07f
Change tqdm disable conditions and deepcopy model map on init.
eagarvey-amd Jul 17, 2024
b20be32
Don't break workarounds for model path
monorimet Jul 17, 2024
02705a9
Fix for passing a path as attn_spec.
eagarvey-amd Jul 18, 2024
9229aed
Bump punet revision to defeb489fe2bb17b77d587924db9e58048a8c140
eagarvey-amd Jul 19, 2024
f09ef4a
Move JIT cpu scheduling load helpers inside conditional.
eagarvey-amd Jul 19, 2024
bbcc424
formatting
eagarvey-amd Jul 19, 2024
1f19c7f
Don't pass benchmark as an export arg.
eagarvey-amd Jul 19, 2024
39c0c00
Changes so no external downloads. (#781)
saienduri Jul 19, 2024
3c59b25
fix so that we check exact paths as well for is_prepared (#782)
saienduri Jul 19, 2024
2e9de46
Update punet to 60edc91
IanNod Jul 20, 2024
aa0ac2b
Vae weight path none check (#784)
saienduri Jul 21, 2024
6556a36
Bump punet to mi300_all_sym_8_step10 (62785ea)
monorimet Jul 22, 2024
2c49cb6
Changes so that the default run without quant docker will work as wel…
saienduri Jul 22, 2024
cb911b1
Bump punet to 361df65844e0a7c766484707c57f6248cea9587f
eagarvey-amd Jul 22, 2024
d857f77
Sync flags to sdxl-scripts repo (#786)
saienduri Jul 23, 2024
37548f2
Integrate int8 tk kernels (#783)
nithinsubbiah Jul 23, 2024
25b2462
Update punet revision to deterministic version (42e9407)
monorimet Jul 23, 2024
0e57b4e
Integration of tk kernels into pipeline (#789)
saienduri Jul 24, 2024
920dbf5
Update unet horizontal fusion flag (#790)
saienduri Jul 25, 2024
6f16731
Revert "Update unet horizontal fusion flag (#790)"
saienduri Jul 25, 2024
15dbd93
[tk kernel] Add support to match kernel with number of arguments and …
nithinsubbiah Jul 25, 2024
0c02652
Add functionality to SD pipeline and abstracted components for saving…
monorimet Jul 25, 2024
3fd954b
Remove download links for tk kernels and instead specify kernel direc…
nithinsubbiah Jul 25, 2024
7f8a2b0
Update to best iteration on unet weights (#794)
saienduri Jul 25, 2024
bf63aec
Add missing tk_kernel_args arg in function calls (#795)
nithinsubbiah Jul 25, 2024
a74d98e
update hash for config file
saienduri Jul 25, 2024
925cd0c
Fix formatting
eagarvey-amd Jul 29, 2024
7715fd0
Point to sdxl-vae-fix branch of iree-turbine.
eagarvey-amd Jul 30, 2024
e276c78
Add SD3 to sd_pipeline
eagarvey-amd Jul 30, 2024
de5d3de
Update test_models.yml
monorimet Jul 30, 2024
d0d3ae6
Remove default in mmdit export args.
eagarvey-amd Jul 30, 2024
403fe47
set vae_harness to False in sdxl test.
eagarvey-amd Jul 30, 2024
0ac6b64
Switch to main branch of iree-turbine
eagarvey-amd Jul 30, 2024
1a41394
Update sd3_vae.py
monorimet Aug 1, 2024
493f260
Remove preprocess arg that fails to parse.
monorimet Aug 1, 2024
711403c
SD3 updates, CLI arguments for multi-device
eagarvey-amd Aug 2, 2024
e554da8
Tweaks to requirements, scheduler filenames
eagarvey-amd Aug 2, 2024
cdd2f66
xfail stateless llama test
monorimet Aug 9, 2024
d23a45b
Flag updates and parametrize a few more args.
eagarvey-amd Aug 13, 2024
7ecfece
Merge branch 'merge_punet_sdxl' of https://github.com/nod-ai/SHARK-Tu…
eagarvey-amd Aug 13, 2024
2d7a92e
Update SDXL tests, README for running on GFX942
eagarvey-amd Aug 15, 2024
18bffdb
Fix vae script CLI and revert precision changes to sd3 text encoders …
eagarvey-amd Aug 17, 2024
df85dca
Merge branch 'merge_punet_sdxl' of https://github.com/nod-ai/SHARK-Tu…
eagarvey-amd Aug 17, 2024
674128e
Small fixes to compile modes and requirements
eagarvey-amd Aug 27, 2024
4d6198b
Adds explicit model arch flag, remove commented code
eagarvey-amd Aug 28, 2024
f3e3fe3
Fix formatting
eagarvey-amd Aug 28, 2024
2ed8037
Merge branch 'main' into merge_punet_sdxl
monorimet Aug 28, 2024
7adfc7a
Fix formatting
eagarvey-amd Aug 28, 2024
ff2c3c9
Update test_models.yml
monorimet Sep 10, 2024
afdb8d6
Decompose CLIP attention
eagarvey-amd Sep 10, 2024
a4e67e8
decompose implementation for clip
eagarvey-amd Sep 11, 2024
35517d9
Add decompose clip flag to pipe e2e test
eagarvey-amd Sep 11, 2024
6ca109a
Add attention decomposition mechanism to sdxl clip exports.
eagarvey-amd Sep 11, 2024
453fb38
Update compile options for sdxl
eagarvey-amd Sep 11, 2024
c0be575
Decompose VAE for cpu
eagarvey-amd Sep 12, 2024
e3cd69d
skip i8 punet test on cpu
eagarvey-amd Sep 12, 2024
e3e1dcb
Don't use spec for clip by default
eagarvey-amd Sep 12, 2024
56d6ee7
Revert change to attention spec handling in sdxl test
monorimet Sep 13, 2024
d330564
Don't use td spec for clip bs2 export test
monorimet Sep 13, 2024
ffba3ea
disable attn spec usage for sdxl bs2 on mi250 tests
monorimet Sep 13, 2024
fad7e6e
Update test_models.yml
monorimet Sep 14, 2024
05fa32d
Update test_models.yml
monorimet Sep 16, 2024
0291d43
Small fixes to SDXL inference pipeline/exports/compile
eagarvey-amd Sep 24, 2024
e337f2a
Pin torch to 2.4.1
eagarvey-amd Oct 2, 2024
0fd8ad0
Largely disables attn spec usage.
eagarvey-amd Oct 3, 2024
e1c4ac2
Update canonicalization pass name, decouple model validation from pip…
eagarvey-amd Oct 3, 2024
61bb4ef
Don't use punet spec.
eagarvey-amd Oct 3, 2024
dfb9474
Remove default/mfma/wmma specs from sd compile utils.
eagarvey-amd Oct 3, 2024
9fe20a6
Guard path check for attn spec
eagarvey-amd Oct 4, 2024
f39b2d2
Separate punet run
eagarvey-amd Oct 4, 2024
d3c8e80
typo fixes
eagarvey-amd Oct 4, 2024
40808db
Filename fixes, explicit input dtypes for i8 punet
eagarvey-amd Oct 4, 2024
e630d39
Update CPU test configuration.
eagarvey-amd Oct 4, 2024
fc6d018
Decompose VAE for cpu
eagarvey-amd Oct 4, 2024
7d50dc8
Change compile flag reporting to CLI input
eagarvey-amd Oct 4, 2024
f140926
formatting
eagarvey-amd Oct 4, 2024
67e6558
Rework prompt encoder export on aot.export API
eagarvey-amd Oct 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/test_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ jobs:
source turbine_venv/bin/activate

pytest -v models/turbine_models/tests/sd_test.py
pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default
pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2
pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 2
18 changes: 4 additions & 14 deletions models/README.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,19 @@
# LLAMA 2 Inference
# Turbine-Models setup (source)

This example require some extra dependencies. Here's an easy way to get it running on a fresh server.

Don't forget to put in your huggingface token from https://huggingface.co/settings/tokens
For private/gated models, make sure you have run `huggingface-cli login`.

```bash
#!/bin/bash


# if you don't insert it, you will be prompted to log in later;
# you may need to rerun this script after logging in
YOUR_HF_TOKEN="insert token for headless"

# clone and install dependencies
sudo apt install -y git
git clone https://github.com/nod-ai/SHARK-Turbine.git
cd SHARK-Turbine
pip install -r core/requirements.txt
pip install torch==2.5.0.dev20240801 torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -r models/requirements.txt

# do an editable install from the cloned SHARK-Turbine
pip install --editable core models
pip install --editable models

# Log in with Hugging Face CLI if token setup is required
if [[ $YOUR_HF_TOKEN == hf_* ]]; then
Expand All @@ -42,6 +35,3 @@ else
huggingface-cli login
fi

# Step 7: Run the Python script
python .\python\turbine_models\custom_models\stateless_llama.py --compile_to=torch --external_weights=safetensors --external_weight_file=llama_f32.safetensors
```
4 changes: 2 additions & 2 deletions models/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
protobuf
gguf
transformers==4.37.1
transformers==4.43.3
torchsde
accelerate
peft
Expand All @@ -13,4 +13,4 @@ einops
pytest
scipy
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
-e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank
-e git+https://github.com/nod-ai/sharktank.git@main#egg=sharktank&subdirectory=sharktank
2 changes: 1 addition & 1 deletion models/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_version_info():
"Shark-Turbine",
"protobuf",
"sentencepiece",
"transformers>=4.37.1",
"transformers>=4.43.3",
"accelerate",
"diffusers==0.29.0.dev0",
"azure-storage-blob",
Expand Down
46 changes: 42 additions & 4 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,24 @@ class PipelineComponent:
"""

def __init__(
self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
self,
printer,
dest_type="devicearray",
dest_dtype="float16",
benchmark=False,
save_outputs=False,
):
self.runner = None
self.module_name = None
self.device = None
self.metadata = None
self.printer = printer
self.benchmark = benchmark
self.save_outputs = save_outputs
self.output_counter = 0
self.dest_type = dest_type
self.dest_dtype = dest_dtype
self.validate = False

def load(
self,
Expand Down Expand Up @@ -218,6 +226,16 @@ def _output_cast(self, output):
case _:
return output

def save_output(self, function_name, output):
if isinstance(output, tuple) or isinstance(output, list):
for i in output:
self.save_output(function_name, i)
else:
np.save(
f"{function_name}_output_{self.output_counter}.npy", output.to_host()
)
self.output_counter += 1

def _run(self, function_name, inputs: list):
return self.module[function_name](*inputs)

Expand All @@ -235,13 +253,21 @@ def __call__(self, function_name, inputs: list):
if not isinstance(inputs, list):
inputs = [inputs]
inputs = self._validate_or_convert_inputs(function_name, inputs)

if self.validate:
self.save_torch_inputs(inputs)

if self.benchmark:
output = self._run_and_benchmark(function_name, inputs)
else:
output = self._run(function_name, inputs)
if self.save_outputs:
self.save_output(function_name, output)
output = self._output_cast(output)
return output

# def _run_and_validate(self, iree_fn, torch_fn, inputs: list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented code



class Printer:
def __init__(self, verbose, start_time, print_time):
Expand Down Expand Up @@ -340,6 +366,7 @@ def __init__(
hf_model_name: str | dict[str] = None,
benchmark: bool | dict[bool] = False,
verbose: bool = False,
save_outputs: bool | dict[bool] = False,
common_export_args: dict = {},
):
self.map = model_map
Expand Down Expand Up @@ -374,6 +401,7 @@ def __init__(
"external_weights": external_weights,
"hf_model_name": hf_model_name,
"benchmark": benchmark,
"save_outputs": save_outputs,
}
for arg in map_arguments.keys():
self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
Expand All @@ -391,7 +419,8 @@ def __init__(
)
for submodel in self.map.keys():
for key, value in map_arguments.items():
self.map = merge_export_arg(self.map, value, key)
if key not in ["benchmark", "save_outputs"]:
self.map = merge_export_arg(self.map, value, key)
for key, value in self.map[submodel].get("export_args", {}).items():
if key == "hf_model_name":
self.map[submodel]["keywords"].append(
Expand Down Expand Up @@ -539,7 +568,11 @@ def is_prepared(self, vmfbs, weights):
avail_files = os.listdir(self.external_weights_dir)
candidates = []
for filename in avail_files:
if all(str(x) in filename for x in w_keywords):
if all(
str(x) in filename
or str(x) == os.path.join(self.external_weights_dir, filename)
for x in w_keywords
):
candidates.append(
os.path.join(self.external_weights_dir, filename)
)
Expand Down Expand Up @@ -723,7 +756,7 @@ def export_submodel(
def load_map(self):
for submodel in self.map.keys():
if not self.map[submodel]["load"]:
self.printer.print("Skipping load for ", submodel)
self.printer.print(f"Skipping load for {submodel}")
continue
self.load_submodel(submodel)

Expand All @@ -739,6 +772,7 @@ def load_submodel(self, submodel):
printer=self.printer,
dest_type=dest_type,
benchmark=self.map[submodel].get("benchmark", False),
save_outputs=self.map[submodel].get("save_outputs", False),
)
self.map[submodel]["runner"].load(
self.map[submodel]["driver"],
Expand All @@ -751,6 +785,10 @@ def load_submodel(self, submodel):

def unload_submodel(self, submodel):
self.map[submodel]["runner"].unload()
self.map[submodel]["vmfb"] = None
self.map[submodel]["mlir"] = None
self.map[submodel]["weights"] = None
self.map[submodel]["export_args"]["input_mlir"] = None
setattr(self, submodel, None)


Expand Down
49 changes: 49 additions & 0 deletions models/turbine_models/custom_models/sd3_inference/diffusers_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from diffusers import StableDiffusion3Pipeline
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we integrate this with a test to output the image numerics we can compare against? I know you saw some significant different numerics between cpu and different gpu backends where this may be difficult to directly compare, maybe some FID/CLIP scores?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing a faithful comparison with diffusers reference is a larger problem -- we are really best off investing in getting real CLIP/FID scores with a validation dataset. This diffusers reference is really just a hold-over/sanity check for now; I don't even trust it to give us a decent baseline from CPU. We can leave this out for now but I'd rather keep it just to have something ready for comparison with diffusers on ROCM/CUDA

import torch
from datetime import datetime as dt


def run_diffusers_cpu(
hf_model_name,
prompt,
negative_prompt,
guidance_scale,
seed,
height,
width,
num_inference_steps,
):
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
hf_model_name, torch_dtype=torch.float32
)
pipe = pipe.to("cpu")
generator = torch.Generator().manual_seed(int(seed))

image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
height=height,
width=width,
generator=generator,
).images[0]
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
image.save(f"diffusers_reference_output_{timestamp}.png")


if __name__ == "__main__":
from turbine_models.custom_models.sd_inference.sd_cmd_opts import args

run_diffusers_cpu(
args.hf_model_name,
args.prompt,
args.negative_prompt,
args.guidance_scale,
args.seed,
args.height,
args.width,
args.num_inference_steps,
)
22 changes: 13 additions & 9 deletions models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def forward(
pooled_projections,
timestep,
):
timestep.expand(hidden_states.shape[0])
noise_pred = self.mmdit(
hidden_states,
encoder_hidden_states,
Expand All @@ -71,7 +72,7 @@ def forward(self, q, k, v):
def export_attn(
precision="fp16",
device="cpu",
target_triple="x86_64-unknown-linux-gnu",
target="x86_64-unknown-linux-gnu",
ireec_flags="",
compile_to="torch",
decomp_attn=False,
Expand Down Expand Up @@ -128,7 +129,7 @@ class CompiledAttn(CompiledModule):
vmfb_path = utils.compile_to_vmfb(
module_str,
device,
target_triple,
target,
ireec_flags,
safe_name,
return_path=True,
Expand All @@ -139,7 +140,6 @@ class CompiledAttn(CompiledModule):

@torch.no_grad()
def export_mmdit_model(
mmdit_model,
hf_model_name,
batch_size,
height,
Expand All @@ -151,8 +151,8 @@ def export_mmdit_model(
external_weights=None,
external_weight_path=None,
device=None,
target_triple=None,
ireec_flags=None,
target=None,
ireec_flags="",
decomp_attn=False,
exit_on_vmfb=False,
pipeline_dir=None,
Expand All @@ -161,6 +161,9 @@ def export_mmdit_model(
weights_only=False,
):
dtype = torch.float16 if precision == "fp16" else torch.float32
mmdit_model = MMDiTModel(
dtype=dtype,
)
np_dtype = "float16" if precision == "fp16" else "float32"
safe_name = utils.create_safe_name(
hf_model_name,
Expand All @@ -169,13 +172,14 @@ def export_mmdit_model(
if pipeline_dir:
safe_name = os.path.join(pipeline_dir, safe_name)
if decomp_attn == True:
safe_name += "_decomp_attn"
ireec_flags += ",--iree-opt-aggressively-propagate-transposes=False"

if input_mlir:
vmfb_path = utils.compile_to_vmfb(
input_mlir,
device,
target_triple,
target,
ireec_flags,
safe_name,
mlir_source="file",
Expand Down Expand Up @@ -208,7 +212,7 @@ def export_mmdit_model(
torch.empty(hidden_states_shape, dtype=dtype),
torch.empty(encoder_hidden_states_shape, dtype=dtype),
torch.empty(pooled_projections_shape, dtype=dtype),
torch.empty(init_batch_dim, dtype=dtype),
torch.empty(1, dtype=dtype),
]

decomp_list = []
Expand Down Expand Up @@ -249,7 +253,7 @@ class CompiledMmdit(CompiledModule):
hidden_states_shape,
encoder_hidden_states_shape,
pooled_projections_shape,
init_batch_dim,
(1,),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is mmdit not working batched?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this timestep input gets expanded in the model to the batch dimension.

],
"input_dtypes": [np_dtype for x in range(4)],
"output_shapes": [hidden_states_shape],
Expand All @@ -263,7 +267,7 @@ class CompiledMmdit(CompiledModule):
vmfb_path = utils.compile_to_vmfb(
module_str,
device,
target_triple,
target,
ireec_flags,
safe_name,
return_path=True,
Expand Down
Loading
Loading