add end to end llama test, including generating and running vmfb #224

Merged · 29 commits · Dec 11, 2023

Commits
27e68b4  add end to end llama test, including generating and running vmfb (Dec 7, 2023)
d6a29dd  adjust naming of tests to look clearer (Dec 7, 2023)
945e9ea  fix formatting with black (Dec 7, 2023)
1981cd6  Fixes cpu flag update for stateless llama from iree bump (#226) (IanNod, Dec 7, 2023)
b5a6192  Stable Diffusion using aot.export and external parameters (#217) (aviator19941, Dec 7, 2023)
f1aa879  fold run_vmfb_comparison into python/turbine_models/tests and actuall… (Dec 7, 2023)
9395248  Adds tests for gen_external_params.py quantize function (#225) (IanNod, Dec 7, 2023)
f6c8008  remove python-fire dependency (Dec 7, 2023)
612556a  remove unnecessary vulcan max_alloc and rename for consistency betwee… (Dec 7, 2023)
200dbe5  black (Dec 7, 2023)
4f58d78  resolve merge conflicts (Dec 7, 2023)
3ab45d1  adjust naming of tests to look clearer (Dec 7, 2023)
e552a1a  fix formatting with black (Dec 7, 2023)
23fd558  fold run_vmfb_comparison into python/turbine_models/tests and actuall… (Dec 7, 2023)
8b32886  remove python-fire dependency (Dec 7, 2023)
75a4fc7  remove unnecessary vulcan max_alloc and rename for consistency betwee… (Dec 7, 2023)
48e2685  black (Dec 7, 2023)
e139f82  resolve merge conflict (Dec 7, 2023)
476b60a  fix discrepancy between vmfb and torch due to one being quantized f16… (Dec 7, 2023)
2289373  finally test cases passed and black applied (Dec 8, 2023)
03c022b  make llama and sd test separate steps (Dec 8, 2023)
c00bc24  typo (Dec 8, 2023)
91b3b72  show mem availability (Dec 8, 2023)
5e31126  fix device issue (renxida, Dec 8, 2023)
cf4b4b1  move args back to beginnign of file (renxida, Dec 11, 2023)
979321c  remove MAX_STEP_SEQ which is not needed for VMFB_COMPARISON (renxida, Dec 11, 2023)
b000642  black (renxida, Dec 11, 2023)
ad865e2  only parse args when in main (renxida, Dec 11, 2023)
3883224  black (renxida, Dec 11, 2023)
16 changes: 12 additions & 4 deletions .github/workflows/test_models.yml
@@ -1,4 +1,4 @@
name: Test
name: Test Turbine Models

on:
workflow_dispatch:
@@ -8,7 +8,7 @@ on:
- main

jobs:
test:
test-turbine-models:
strategy:
matrix:
version: [3.11]
@@ -36,7 +36,15 @@ jobs:
pip install --upgrade -r requirements.txt
pip install -e .[testing]
pip install -e python/turbine_models

- name: Show current free memory
run: |
free -mh

- name: Run stateless_llama tests
run: |
pytest python/turbine_models/tests/stateless_llama_test.py

- name: Run tests
- name: Run sd tests
run: |
pytest python/turbine_models/tests
pytest python/turbine_models/tests/sd_test.py
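Read together, the workflow hunks above rename the job, add a memory report, and split the single "Run tests" step into separate llama and sd steps. A rough reconstruction of the resulting step list (indentation restored from the diff; not a verbatim copy of the file):

      - name: Show current free memory
        run: |
          free -mh

      - name: Run stateless_llama tests
        run: |
          pytest python/turbine_models/tests/stateless_llama_test.py

      - name: Run sd tests
        run: |
          pytest python/turbine_models/tests/sd_test.py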
1 change: 1 addition & 0 deletions python/shark_turbine/dynamo/passes.py
@@ -46,6 +46,7 @@
    torch.ops.aten._to_copy,
    torch.ops.aten._log_softmax_backward_data,
    torch.ops.aten.lift_fresh_copy.default,
    torch.ops.aten._unsafe_index.Tensor,
]


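The single added line above registers aten._unsafe_index.Tensor in the op list that passes.py maintains for decomposition. A minimal sketch of the effect, assuming (as is typical for Dynamo backends, though not shown in this hunk) that the list is passed to torch._decomp.get_decompositions:

import torch
from torch._decomp import get_decompositions

# Hypothetical two-entry subset of the list in passes.py; registering
# aten._unsafe_index.Tensor here means the backend receives a decomposed
# form of the op instead of the op itself.
decomp_table = get_decompositions(
    [
        torch.ops.aten._to_copy,
        torch.ops.aten._unsafe_index.Tensor,
    ]
)
print(len(decomp_table))  # one decomposition function per resolved overload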
10 changes: 10 additions & 0 deletions python/shark_turbine/importers/fx_importer.py
@@ -628,6 +628,16 @@ def _import_torch_op_overload(
        elif target == torch.ops.aten.lift_fresh_copy.out:
            node.target = target = torch.ops.aten.clone.out
            node.args = (node.args[0], None, node.args[1])
        # TODO: generalize empty.memory_format in the future
        # Currently, the aten.baddbmm.default op for Unet includes multiplying an
        # empty.memory_format input with a constant, which creates NaN values
        # because empty.memory_format contains uninitialized data. Converting
        # aten.empty.memory_format -> aten.zeros.default fixes the correctness issue.
        elif target == torch.ops.aten.empty.memory_format:
            if len(node.users) == 1:
                for key_node in node.users:
                    if key_node.target == torch.ops.aten.baddbmm.default:
                        node.target = target = torch.ops.aten.zeros.default

        schema = target._schema
        assert isinstance(schema, FunctionSchema)
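The hunk above rewrites aten.empty.memory_format to aten.zeros.default when its only consumer is aten.baddbmm.default. A small stand-alone PyTorch illustration of why that matters (not Turbine code; shapes are arbitrary):

import torch

batch, n, m, p = 2, 3, 4, 5
a = torch.randn(batch, n, m)
b = torch.randn(batch, m, p)

# torch.empty returns whatever bytes happen to be in the allocation, so the
# bias term added by baddbmm is garbage (possibly NaN/inf); torch.zeros makes
# the bias a well-defined no-op.
uninitialized_bias = torch.empty(batch, n, p)
zero_bias = torch.zeros(batch, n, p)

out_bad = torch.baddbmm(uninitialized_bias, a, b)  # nondeterministic values
out_good = torch.baddbmm(zero_bias, a, b)          # equals torch.bmm(a, b)
print(torch.allclose(out_good, torch.bmm(a, b)))   # True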
201 changes: 201 additions & 0 deletions python/turbine_models/custom_models/sd_inference/clip.py
@@ -0,0 +1,201 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys
import re

from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
from transformers import CLIPTextModel, CLIPTokenizer

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument("--run_vmfb", action="store_true")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_file", type=str, default="")
parser.add_argument("--vmfb_path", type=str, default="")
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
)
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")

prompt = ["a photograph of an astronaut riding a horse"]


def export_clip_model(
    hf_model_name,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_file=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    # Load the tokenizer and text encoder to tokenize and encode the text.
    tokenizer = CLIPTokenizer.from_pretrained(
        hf_model_name,
        subfolder="tokenizer",
        token=hf_auth_token,
    )
    text_encoder_model = CLIPTextModel.from_pretrained(
        hf_model_name,
        subfolder="text_encoder",
        token=hf_auth_token,
    )

    mapper = {}
    utils.save_external_weights(
        mapper, text_encoder_model, external_weights, external_weight_file
    )

    class CompiledClip(CompiledModule):
        if external_weights:
            params = export_parameters(
                text_encoder_model,
                external=True,
                external_scope="",
                name_mapper=mapper.get,
            )
        else:
            params = export_parameters(text_encoder_model)

        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
            return jittable(text_encoder_model.forward)(inp)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledClip(context=Context(), import_to=import_to)

    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
    if compile_to != "vmfb":
        return module_str, tokenizer
    else:
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


def run_clip_vmfb_comparison(args):
    config = ireert.Config(args.device)

    if args.external_weight_file:
        index = ireert.ParameterIndex()
        index.load(args.external_weight_file)

    safe_name = args.hf_model_name.split("/")[-1].strip()
    safe_name = re.sub("-", "_", safe_name)
    if args.vmfb_path:
        mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path)
    elif os.path.exists(f"{safe_name}.vmfb"):
        mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb")
    else:
        sys.exit("no vmfb_path provided, required for run_vmfb")

    vm_modules = [
        mod,
        ireert.create_hal_module(config.vm_instance, config.device),
    ]
    if args.external_weight_file:
        param_module = ireert.create_io_parameters_module(
            config.vm_instance, index.create_provider(scope="model")
        )
        vm_modules.insert(0, param_module)

    ctx = ireert.SystemContext(
        vm_modules=vm_modules,
        config=config,
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_name,
        subfolder="tokenizer",
        token=args.hf_auth_token,
    )
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    inp = text_input.input_ids
    device_inputs = [ireert.asdevicearray(config.device, inp)]

    # Turbine output
    ModuleCompiled = ctx.modules.compiled_clip
    turbine_outputs = ModuleCompiled["main"](*device_inputs)
    turbine_output = turbine_outputs[0]
    print(
        "TURBINE OUTPUT:",
        turbine_output.to_host(),
        turbine_output.to_host().shape,
        turbine_output.to_host().dtype,
    )

    # Torch output
    text_encoder_model = CLIPTextModel.from_pretrained(
        args.hf_model_name,
        subfolder="text_encoder",
        token=args.hf_auth_token,
    )
    torch_output = text_encoder_model.forward(inp)[0]
    np_torch_output = torch_output.detach().cpu().numpy()
    print(
        "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype
    )

    err = utils.largest_error(np_torch_output, turbine_output)
    print("LARGEST ERROR:", err)
    assert err < 9e-5


if __name__ == "__main__":
    args = parser.parse_args()
    if args.run_vmfb:
        run_clip_vmfb_comparison(args)
    else:
        mod_str, _ = export_clip_model(
            args.hf_model_name,
            args.hf_auth_token,
            args.compile_to,
            args.external_weights,
            args.external_weight_file,
            args.device,
            args.iree_target_triple,
            args.vulkan_max_allocation,
        )
        safe_name = args.hf_model_name.split("/")[-1].strip()
        safe_name = re.sub("-", "_", safe_name)
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")