From 27e68b4203c742cc8fbbbc70c223b19ea416c00c Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 01:06:10 +0000 Subject: [PATCH 01/28] add end to end llama test, including generating and running vmfb --- .github/workflows/test_models.yml | 2 +- .../custom_models/stateless_llama.py | 13 ++--- .../gen_external_params.py | 45 ++++++++++----- python/turbine_models/tests/conftest.py | 18 ++++++ python/turbine_models/tests/llama_test.py | 30 ---------- .../tests/stateless_llama_test.py | 57 +++++++++++++++++++ 6 files changed, 113 insertions(+), 52 deletions(-) create mode 100644 python/turbine_models/tests/conftest.py delete mode 100644 python/turbine_models/tests/llama_test.py create mode 100644 python/turbine_models/tests/stateless_llama_test.py diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 3227c2769..36da25eea 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -1,4 +1,4 @@ -name: Test +name: turbine_models test on: workflow_dispatch: diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 0d3d749d2..d7144e8cb 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -48,12 +48,12 @@ "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +parser.add_argument("--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm") # TODO: Bring in detection for target triple parser.add_argument( "--iree_target_triple", type=str, - default="", + default="host", help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") @@ -236,9 +236,8 @@ def forward(token0: torch.Tensor, *state0_flat): "--iree-codegen-check-ir-before-llvm-conversion=false", "--iree-opt-const-expr-hoisting=False", ] - if device == "cpu": - flags.append("--iree-llvmcpu-enable-microkernels") - device = "llvm-cpu" + if device == "llvm-cpu": + flags.append("--iree-llvmcpu-enable-ukernels=all") elif device == "vulkan": flags.extend( [ @@ -277,11 +276,11 @@ def forward(token0: torch.Tensor, *state0_flat): with open(f"{safe_name}.vmfb", "wb+") as f: f.write(flatbuffer_blob) print("saved to ", safe_name + ".vmfb") - exit() + return module_str, tokenizer def run_vmfb_comparison(args): - config = ireert.Config(args.device) + config = ireert.Config("local-task") if args.external_weight_file: index = ireert.ParameterIndex() diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 7e993de72..de684e9f5 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -72,35 +72,52 @@ def forward(self, x): all_weights.update(int_weights) return all_weights +def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", + quantization:str = "int4", + weight_path:str = "", + hf_auth_token:str = None, + precision:str = "f16"): + """ + Main function to run the model quantization and saving process. -if __name__ == "__main__": - args = parser.parse_args() + :param hf_model_name: The Hugging Face model name ID. + :param quantization: Type of quantization to apply ('int4' or 'int8'). 
+ :param weight_path: Path to save the quantized model weights. + :param hf_auth_token: The Hugging Face auth token required for some models. + :param precision: Data type of model ('f16' or 'f32'). + """ model_builder = HFTransformerBuilder( example_input=None, - hf_id=args.hf_model_name, + hf_id=hf_model_name, auto_model=AutoModelForCausalLM, - hf_auth_token=args.hf_auth_token, + hf_auth_token=hf_auth_token, ) model_builder.build_model() - if args.precision == "f16": + + if precision == "f16": model = model_builder.model.half() dtype = torch.float16 - elif args.precision == "f32": + elif precision == "f32": model = model_builder.model dtype = torch.float32 else: - sys.exit("invalid precision, f16 or f32 supported") - quant_weights = quantize(model, args.quantization, dtype) - # TODO: Add more than just safetensor support - import safetensors + sys.exit("Invalid precision, f16 or f32 supported") + + quant_weights = quantize(model, quantization, dtype) - if args.weight_path == "": - save_path = args.hf_model_name.split("/")[-1].strip() + if weight_path == "": + save_path = hf_model_name.split("/")[-1].strip() save_path = re.sub("-", "_", save_path) save_path = ( - save_path + "_" + args.precision + "_" + args.quantization + ".safetensors" + save_path + "_" + precision + "_" + quantization + ".safetensors" ) else: - save_path = args.weight_path + save_path = weight_path + + import safetensors safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) + +if __name__ == "__main__": + import fire + fire.Fire(gen_external_params) \ No newline at end of file diff --git a/python/turbine_models/tests/conftest.py b/python/turbine_models/tests/conftest.py new file mode 100644 index 000000000..ae6071971 --- /dev/null +++ b/python/turbine_models/tests/conftest.py @@ -0,0 +1,18 @@ +def pytest_addoption(parser): + parser.addoption("--all", action="store_true", help="run all combinations") + + +def pytest_generate_tests(metafunc): + if "quantization" in metafunc.fixturenames: + if metafunc.config.getoption("all"): + quantizations = ["int4", None] + else: + quantizations = ["int4"] + metafunc.parametrize("quantization", quantizations) + + if "precision" in metafunc.fixturenames: + if metafunc.config.getoption("all"): + precisions = ["f16", "f32"] + else: + precisions = ["f16"] + metafunc.parametrize("precision", precisions) \ No newline at end of file diff --git a/python/turbine_models/tests/llama_test.py b/python/turbine_models/tests/llama_test.py deleted file mode 100644 index df44a0e70..000000000 --- a/python/turbine_models/tests/llama_test.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import turbine_models.custom_models.stateless_llama as llama -import unittest -import os - - -class LLamaTest(unittest.TestCase): - def testExportTransformerModel(self): - llama.export_transformer_model( - # This is a public model, so no auth required - "llSourcell/medllama2_7b", - None, - "torch", - "safetensors", - "medllama2_f32.safetensors", - None, - "f32", - ) - os.remove("medllama2_f32.safetensors") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py new file mode 100644 index 000000000..0aaeae6ff --- /dev/null +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import turbine_models.custom_models.stateless_llama as llama +import unittest +import os +import pytest + +from typing import Literal + +def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): + + llama.export_transformer_model( + hf_model_name="llSourcell/medllama2_7b", + hf_auth_token=None, + compile_to="vmfb", + external_weights="safetensors", + # external_weight_file="medllama2_7b_f16_int4.safetensors", Do not export weights because this doesn't get quantized + quantization=quantization, + precision=precision, + device="llvm-cpu", + target_triple="host", + max_alloc = "4294967296" + ) + + from turbine_models.gen_external_params.gen_external_params import gen_external_params + gen_external_params( + hf_model_name="llSourcell/medllama2_7b", + quantization=quantization, + weight_path="medllama2_7b_f16_int4.safetensors", + hf_auth_token=None, + precision=precision + ) + + from types import SimpleNamespace + args = SimpleNamespace() + args.hf_model_name = "llSourcell/medllama2_7b" + args.hf_auth_token = None + args.vmfb_path = "medllama2_7b.vmfb" + args.external_weight_file = "medllama2_7b_f16_int4.safetensors" + args.run_vmfb = True + args.device="llvm-cpu" + args.precision = precision + args.quantization = quantization + args.iree_target_triple="host" + llama.run_vmfb_comparison(args) + + + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From d6a29dd125782506cc92050dd8e861aae45b6108 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 01:12:34 +0000 Subject: [PATCH 02/28] adjust naming of tests to look clearer --- .github/workflows/test_models.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 36da25eea..e353273e2 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -1,4 +1,4 @@ -name: turbine_models test +name: Test Turbine Models on: workflow_dispatch: @@ -8,7 +8,7 @@ on: - main jobs: - test: + test-turbine-models: strategy: matrix: version: [3.11] From 945e9eade1cc69ec7066e7f401fd221255ca28f6 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 01:13:28 +0000 Subject: [PATCH 03/28] fix formatting with black --- .../custom_models/stateless_llama.py | 4 +++- python/turbine_models/tests/conftest.py | 4 ++-- .../tests/stateless_llama_test.py | 18 ++++++++++-------- 3 files changed, 15 insertions(+), 11 
deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index d7144e8cb..2f21d5da5 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -48,7 +48,9 @@ "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) -parser.add_argument("--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm") +parser.add_argument( + "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm" +) # TODO: Bring in detection for target triple parser.add_argument( "--iree_target_triple", diff --git a/python/turbine_models/tests/conftest.py b/python/turbine_models/tests/conftest.py index ae6071971..1b1fe1f67 100644 --- a/python/turbine_models/tests/conftest.py +++ b/python/turbine_models/tests/conftest.py @@ -9,10 +9,10 @@ def pytest_generate_tests(metafunc): else: quantizations = ["int4"] metafunc.parametrize("quantization", quantizations) - + if "precision" in metafunc.fixturenames: if metafunc.config.getoption("all"): precisions = ["f16", "f32"] else: precisions = ["f16"] - metafunc.parametrize("precision", precisions) \ No newline at end of file + metafunc.parametrize("precision", precisions) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 0aaeae6ff..4b957afb3 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -12,8 +12,8 @@ from typing import Literal + def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): - llama.export_transformer_model( hf_model_name="llSourcell/medllama2_7b", hf_auth_token=None, @@ -24,32 +24,34 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " precision=precision, device="llvm-cpu", target_triple="host", - max_alloc = "4294967296" + max_alloc="4294967296", + ) + + from turbine_models.gen_external_params.gen_external_params import ( + gen_external_params, ) - from turbine_models.gen_external_params.gen_external_params import gen_external_params gen_external_params( hf_model_name="llSourcell/medllama2_7b", quantization=quantization, weight_path="medllama2_7b_f16_int4.safetensors", hf_auth_token=None, - precision=precision + precision=precision, ) from types import SimpleNamespace + args = SimpleNamespace() args.hf_model_name = "llSourcell/medllama2_7b" args.hf_auth_token = None args.vmfb_path = "medllama2_7b.vmfb" args.external_weight_file = "medllama2_7b_f16_int4.safetensors" args.run_vmfb = True - args.device="llvm-cpu" + args.device = "llvm-cpu" args.precision = precision args.quantization = quantization - args.iree_target_triple="host" + args.iree_target_triple = "host" llama.run_vmfb_comparison(args) - - if __name__ == "__main__": From 1981cd64d31fa87477e87320fc3431fe2958d06b Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Thu, 7 Dec 2023 09:48:38 -0800 Subject: [PATCH 04/28] Fixes cpu flag update for stateless llama from iree bump (#226) --- python/turbine_models/custom_models/stateless_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 0d3d749d2..61727d673 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ 
b/python/turbine_models/custom_models/stateless_llama.py @@ -237,7 +237,7 @@ def forward(token0: torch.Tensor, *state0_flat): "--iree-opt-const-expr-hoisting=False", ] if device == "cpu": - flags.append("--iree-llvmcpu-enable-microkernels") + flags.append("--iree-llvmcpu-enable-ukernels=all") device = "llvm-cpu" elif device == "vulkan": flags.extend( From b5a61926b0591e5e822c1b56b248c6b2b4ec4abe Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Thu, 7 Dec 2023 10:15:29 -0800 Subject: [PATCH 05/28] Stable Diffusion using aot.export and external parameters (#217) - Saves weights to .safetensors file - Load weights at runtime with a "stripped" .mlir --- python/shark_turbine/dynamo/passes.py | 1 + python/shark_turbine/importers/fx_importer.py | 10 + .../custom_models/sd_inference/clip.py | 201 ++++++++++++++++ .../custom_models/sd_inference/unet.py | 215 ++++++++++++++++++ .../custom_models/sd_inference/utils.py | 83 +++++++ .../custom_models/sd_inference/vae.py | 180 +++++++++++++++ python/turbine_models/setup.py | 2 + python/turbine_models/tests/llama_test.py | 1 + python/turbine_models/tests/sd_test.py | 101 ++++++++ turbine-models-requirements.txt | 2 + 10 files changed, 796 insertions(+) create mode 100644 python/turbine_models/custom_models/sd_inference/clip.py create mode 100644 python/turbine_models/custom_models/sd_inference/unet.py create mode 100644 python/turbine_models/custom_models/sd_inference/utils.py create mode 100644 python/turbine_models/custom_models/sd_inference/vae.py create mode 100644 python/turbine_models/tests/sd_test.py diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 1e8e2058f..478944df6 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -46,6 +46,7 @@ torch.ops.aten._to_copy, torch.ops.aten._log_softmax_backward_data, torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten._unsafe_index.Tensor, ] diff --git a/python/shark_turbine/importers/fx_importer.py b/python/shark_turbine/importers/fx_importer.py index e2707f35b..76d7f9978 100644 --- a/python/shark_turbine/importers/fx_importer.py +++ b/python/shark_turbine/importers/fx_importer.py @@ -628,6 +628,16 @@ def _import_torch_op_overload( elif target == torch.ops.aten.lift_fresh_copy.out: node.target = target = torch.ops.aten.clone.out node.args = (node.args[0], None, node.args[1]) + # TODO: generalize empty.memory_format in the future + # Currently, the aten.baddbmm.default op for Unet includes multiplying an + # empty.memory_format input with a constant, which creates NaN values + # because empty.memory_format contains uninitialized data. Converting + # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue + elif target == torch.ops.aten.empty.memory_format: + if len(node.users) == 1: + for key_node in node.users: + if key_node.target == torch.ops.aten.baddbmm.default: + node.target = target = torch.ops.aten.zeros.default schema = target._schema assert isinstance(schema, FunctionSchema) diff --git a/python/turbine_models/custom_models/sd_inference/clip.py b/python/turbine_models/custom_models/sd_inference/clip.py new file mode 100644 index 000000000..ec2dcb3fb --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/clip.py @@ -0,0 +1,201 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import re + +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from transformers import CLIPTextModel, CLIPTokenizer + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument("--run_vmfb", action="store_true") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_file", type=str, default="") +parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + +prompt = ["a photograph of an astronaut riding a horse"] + + +def export_clip_model( + hf_model_name, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_file=None, + device=None, + target_triple=None, + max_alloc=None, +): + # Load the tokenizer and text encoder to tokenize and encode the text. 
+ tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + text_encoder_model = CLIPTextModel.from_pretrained( + hf_model_name, + subfolder="text_encoder", + token=hf_auth_token, + ) + + mapper = {} + utils.save_external_weights( + mapper, text_encoder_model, external_weights, external_weight_file + ) + + class CompiledClip(CompiledModule): + if external_weights: + params = export_parameters( + text_encoder_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(text_encoder_model) + + def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)): + return jittable(text_encoder_model.forward)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledClip(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str, tokenizer + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +def run_clip_vmfb_comparison(args): + config = ireert.Config(args.device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if args.vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + sys.exit("no vmfb_path provided, required for run_vmfb") + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] + if args.external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=config, + ) + tokenizer = CLIPTokenizer.from_pretrained( + args.hf_model_name, + subfolder="tokenizer", + token=args.hf_auth_token, + ) + text_input = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + inp = text_input.input_ids + device_inputs = [ireert.asdevicearray(config.device, inp)] + + # Turbine output + ModuleCompiled = ctx.modules.compiled_clip + turbine_outputs = ModuleCompiled["main"](*device_inputs) + turbine_output = turbine_outputs[0] + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + # Torch output + text_encoder_model = CLIPTextModel.from_pretrained( + args.hf_model_name, + subfolder="text_encoder", + token=args.hf_auth_token, + ) + torch_output = text_encoder_model.forward(inp)[0] + np_torch_output = torch_output.detach().cpu().numpy() + print( + "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype + ) + + err = utils.largest_error(np_torch_output, turbine_output) + print("LARGEST ERROR:", err) + assert err < 9e-5 + + +if __name__ == "__main__": + args = parser.parse_args() + if args.run_vmfb: + run_clip_vmfb_comparison(args) + else: + mod_str, _ = export_clip_model( + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = 
args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/unet.py b/python/turbine_models/custom_models/sd_inference/unet.py new file mode 100644 index 000000000..4c4d3c227 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/unet.py @@ -0,0 +1,215 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import re + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import UNet2DConditionModel + +import safetensors +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument("--run_vmfb", action="store_true") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_file", type=str, default="") +parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + + +class UnetModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + token=hf_auth_token, + ) + self.guidance_scale = 7.5 + + def forward(self, sample, timestep, encoder_hidden_states): + samples = torch.cat([sample] * 2) + unet_out = self.unet.forward( + samples, timestep, encoder_hidden_states, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + +def export_unet_model( + unet_model, + hf_model_name, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_file=None, + device=None, + target_triple=None, + max_alloc=None, +): + mapper = {} + utils.save_external_weights( + mapper, unet_model, external_weights, external_weight_file + ) + + encoder_hidden_states_sizes = (2, 77, 768) + if hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states_sizes = (2, 77, 1024) + + class CompiledUnet(CompiledModule): + if external_weights: + params = export_parameters( + unet_model, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(unet_model) + + def main( + self, + sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32), + timestep=AbstractTensor(1, 
dtype=torch.float32), + encoder_hidden_states=AbstractTensor( + *encoder_hidden_states_sizes, dtype=torch.float32 + ), + ): + return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledUnet(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +def run_unet_vmfb_comparison(unet_model, args): + config = ireert.Config(args.device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if args.vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + sys.exit("no vmfb_path provided, required for run_vmfb") + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] + if args.external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=config, + ) + sample = torch.rand(1, 4, 64, 64, dtype=torch.float32) + timestep = torch.zeros(1, dtype=torch.float32) + if args.hf_model_name == "CompVis/stable-diffusion-v1-4": + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + + device_inputs = [ + ireert.asdevicearray(config.device, sample), + ireert.asdevicearray(config.device, timestep), + ireert.asdevicearray(config.device, encoder_hidden_states), + ] + + # Turbine output + ModuleCompiled = ctx.modules.compiled_unet + turbine_output = ModuleCompiled["main"](*device_inputs) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + # Torch output + torch_output = unet_model.forward(sample, timestep, encoder_hidden_states) + np_torch_output = torch_output.detach().cpu().numpy() + print( + "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype + ) + + err = utils.largest_error(np_torch_output, turbine_output) + print("LARGEST ERROR:", err) + assert err < 9e-5 + + +if __name__ == "__main__": + args = parser.parse_args() + unet_model = UnetModel( + args.hf_model_name, + args.hf_auth_token, + ) + if args.run_vmfb: + run_unet_vmfb_comparison(unet_model, args) + else: + mod_str = export_unet_model( + unet_model, + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/utils.py b/python/turbine_models/custom_models/sd_inference/utils.py new file mode 100644 index 000000000..e2f4cfd96 --- 
/dev/null +++ b/python/turbine_models/custom_models/sd_inference/utils.py @@ -0,0 +1,83 @@ +import iree.compiler as ireec +import numpy as np +import safetensors + + +def save_external_weights( + mapper, + model, + external_weights=None, + external_weight_file=None, +): + if external_weights is not None: + if external_weights == "safetensors": + mod_params = dict(model.named_parameters()) + for name in mod_params: + mapper["params." + name] = name + if external_weight_file: + safetensors.torch.save_file(mod_params, external_weight_file) + print("Saved params to", external_weight_file) + + +def largest_error(array1, array2): + absolute_diff = np.abs(array1 - array2) + max_error = np.max(absolute_diff) + return max_error + + +def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): + flags = [ + "--iree-input-type=torch", + "--mlir-print-debuginfo", + "--mlir-print-op-on-diagnostic=false", + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-target-triple=x86_64-linux-gnu", + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-opt-const-expr-hoisting=False", + ] + if device == "cpu": + flags.append("--iree-llvmcpu-enable-ukernels=all") + device = "llvm-cpu" + elif device == "vulkan": + flags.extend( + [ + "--iree-hal-target-backends=vulkan-spirv", + "--iree-vulkan-target-triple=" + target_triple, + "--iree-stream-resource-max-allocation-size=" + max_alloc, + ] + ) + elif device == "rocm": + flags.extend( + [ + "--iree-hal-target-backends=rocm", + "--iree-rocm-target-chip=" + target_triple, + "--iree-rocm-link-bc=true", + "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", + "--iree-vm-bytecode-module-strip-source-map=true", + "--iree-opt-strip-assertions=true", + "--iree-vm-target-truncate-unsupported-floats", + ] + ) + elif device == "cuda": + flags.extend( + [ + "--iree-hal-target-backends=cuda", + "--iree-hal-cuda-llvm-target-arch=" + target_triple, + "--iree-vm-bytecode-module-strip-source-map=true", + "--iree-vm-target-truncate-unsupsported-floats", + ] + ) + else: + print("incorrect device: ", device) + + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + extra_args=flags, + ) + with open(f"{safe_name}.vmfb", "wb+") as f: + f.write(flatbuffer_blob) + print("Saved to", safe_name + ".vmfb") + exit() diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py new file mode 100644 index 000000000..cf3c587a1 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -0,0 +1,180 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import re + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import AutoencoderKL + +import safetensors +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument("--run_vmfb", action="store_true") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_file", type=str, default="") +parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + + +class VaeModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + token=hf_auth_token, + ) + + def forward(self, inp): + with torch.no_grad(): + x = self.vae.decode(inp, return_dict=False)[0] + return x + + +def export_vae_model( + vae_model, + hf_model_name, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_file=None, + device=None, + target_triple=None, + max_alloc=None, +): + mapper = {} + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_file + ) + + class CompiledVae(CompiledModule): + params = export_parameters(vae_model) + + def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)): + return jittable(vae_model.forward)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledVae(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +def run_vae_vmfb_comparison(vae_model, args): + config = ireert.Config(args.device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if args.vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + sys.exit("no vmfb_path provided, required for run_vmfb") + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] + if args.external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, 
param_module) + + ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=config, + ) + inp = torch.rand(1, 4, 64, 64, dtype=torch.float32) + device_inputs = [ireert.asdevicearray(config.device, inp)] + + # Turbine output + ModuleCompiled = ctx.modules.compiled_vae + turbine_output = ModuleCompiled["main"](*device_inputs) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + # Torch output + torch_output = vae_model.forward(inp) + torch_output = torch_output.detach().cpu().numpy() + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + + err = utils.largest_error(torch_output, turbine_output) + print("LARGEST ERROR:", err) + assert err < 9e-5 + + +if __name__ == "__main__": + args = parser.parse_args() + vae_model = VaeModel( + args.hf_model_name, + args.hf_auth_token, + ) + if args.run_vmfb: + run_vae_vmfb_comparison(vae_model, args) + else: + mod_str = export_vae_model( + vae_model, + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/setup.py b/python/turbine_models/setup.py index 33aa3fdb5..4fbc0d256 100644 --- a/python/turbine_models/setup.py +++ b/python/turbine_models/setup.py @@ -63,5 +63,7 @@ def load_version_info(): "protobuf", "sentencepiece", "transformers", + "accelerate", + "diffusers==0.10.2", ], ) diff --git a/python/turbine_models/tests/llama_test.py b/python/turbine_models/tests/llama_test.py index df44a0e70..ec75f8656 100644 --- a/python/turbine_models/tests/llama_test.py +++ b/python/turbine_models/tests/llama_test.py @@ -21,6 +21,7 @@ def testExportTransformerModel(self): "medllama2_f32.safetensors", None, "f32", + "cpu", ) os.remove("medllama2_f32.safetensors") diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py new file mode 100644 index 000000000..c3834840d --- /dev/null +++ b/python/turbine_models/tests/sd_test.py @@ -0,0 +1,101 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import logging +from turbine_models.custom_models.sd_inference import clip, unet, vae +import unittest +import os + + +arguments = { + "hf_auth_token": None, + "hf_model_name": "CompVis/stable-diffusion-v1-4", + "run_vmfb": True, + "compile_to": None, + "external_weight_file": "", + "vmfb_path": "", + "external_weights": None, + "device": "local-task", + "iree_target_triple": "", + "vulkan_max_allocation": "4294967296", +} + + +unet_model = unet.UnetModel( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, +) + +vae_model = vae.VaeModel( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, +) + + +class StableDiffusionTest(unittest.TestCase): + def testExportClipModel(self): + with self.assertRaises(SystemExit) as cm: + clip.export_clip_model( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_clip.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_file"] = "stable_diffusion_v1_4_clip.safetensors" + namespace = argparse.Namespace(**arguments) + clip.run_clip_vmfb_comparison(namespace) + os.remove("stable_diffusion_v1_4_clip.safetensors") + os.remove("stable_diffusion_v1_4.vmfb") + + def testExportUnetModel(self): + with self.assertRaises(SystemExit) as cm: + unet.export_unet_model( + unet_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_unet.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_file"] = "stable_diffusion_v1_4_unet.safetensors" + namespace = argparse.Namespace(**arguments) + unet.run_unet_vmfb_comparison(unet_model, namespace) + os.remove("stable_diffusion_v1_4_unet.safetensors") + os.remove("stable_diffusion_v1_4.vmfb") + + def testExportVaeModel(self): + with self.assertRaises(SystemExit) as cm: + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_vae.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_file"] = "stable_diffusion_v1_4_vae.safetensors" + namespace = argparse.Namespace(**arguments) + vae.run_vae_vmfb_comparison(vae_model, namespace) + os.remove("stable_diffusion_v1_4_vae.safetensors") + os.remove("stable_diffusion_v1_4.vmfb") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/turbine-models-requirements.txt b/turbine-models-requirements.txt index 606fea685..c93f5dcec 100644 --- a/turbine-models-requirements.txt +++ b/turbine-models-requirements.txt @@ -2,4 +2,6 @@ protobuf sentencepiece shark_turbine transformers +accelerate +diffusers==0.10.2 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b From f1aa87924376a9b520bb1cefd22e4a46e8c1a749 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:21:13 +0000 Subject: [PATCH 06/28] fold run_vmfb_comparison into python/turbine_models/tests and actually fail tests on comparison fail --- .../custom_models/stateless_llama.py | 119 ++----------- .../tests/stateless_llama_test.py | 47 +++-- .../turbine_models/tests/vmfb_comparison.py | 163 ++++++++++++++++++ 3 files changed, 213 insertions(+), 116 
deletions(-) create mode 100644 python/turbine_models/tests/vmfb_comparison.py diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 2f21d5da5..48239fb8f 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -19,7 +19,6 @@ import argparse parser = argparse.ArgumentParser() -parser.add_argument("--run_vmfb", action="store_true") parser.add_argument( "--hf_auth_token", type=str, help="The Hugging Face auth token, required" ) @@ -280,107 +279,25 @@ def forward(token0: torch.Tensor, *state0_flat): print("saved to ", safe_name + ".vmfb") return module_str, tokenizer +# if you're looking for run_vmfb_comparison, it's now in python/turbine_models/tests/vmfb_comparison.py -def run_vmfb_comparison(args): - config = ireert.Config("local-task") - - if args.external_weight_file: - index = ireert.ParameterIndex() - index.load(args.external_weight_file) - - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - if args.vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - sys.exit("no vmfb_path provided, required for run_vmfb") - - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device), - ] - if args.external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) - - ctx = ireert.SystemContext( - vm_modules=vm_modules, - config=config, - ) - tokenizer = AutoTokenizer.from_pretrained( - args.hf_model_name, - use_fast=False, - token=args.hf_auth_token, - ) - initial_input = tokenizer(prompt, return_tensors="pt") - example_input_id = initial_input.input_ids - device_inputs = [ireert.asdevicearray(config.device, example_input_id)] - - ModuleCompiled = ctx.modules.state_update - results = ModuleCompiled["run_initialize"](*device_inputs) - - def format_out(results): - return torch.tensor(results.to_host()[0][0]) +if __name__ == "__main__": + args = parser.parse_args() - model = AutoModelForCausalLM.from_pretrained( + mod_str, _ = export_transformer_model( args.hf_model_name, - torch_dtype=torch.float, - use_auth_token=args.hf_auth_token, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.quantization, + args.precision, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, ) - - def get_token_from_logits(logits): - return torch.argmax(logits[:, -1, :], dim=1) - - base_model_results = model.forward(example_input_id) - base_model_token = get_token_from_logits(base_model_results.logits) - bm_pkv = base_model_results.past_key_values - turbine_results = [] - torch_results = [] - turbine_results.append(format_out(results)) - torch_results.append(int(base_model_token)) - while base_model_token != 2: - results = ModuleCompiled["run_forward"](results) - base_model_results = model.forward( - torch.unsqueeze(base_model_token, 0), past_key_values=bm_pkv - ) - base_model_token = get_token_from_logits(base_model_results.logits) - - bm_pkv = base_model_results.past_key_values - # uncomment to see tokens as they are emittd - # print(f"pytorch: {tokenizer.decode(base_model_token)}") - # print(f"turbine: {tokenizer.decode(format_out(results))}") - 
turbine_results.append(format_out(results)) - torch_results.append(int(base_model_token[0])) - - print("turbine output: ") - print(tokenizer.decode(turbine_results)) - print("\ntorch output: ") - print(tokenizer.decode(torch_results)) - - -if __name__ == "__main__": - args = parser.parse_args() - if args.run_vmfb: - run_vmfb_comparison(args) - else: - mod_str, _ = export_transformer_model( - args.hf_model_name, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_file, - args.quantization, - args.precision, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to ", safe_name + ".mlir") + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to ", safe_name + ".mlir") diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 4b957afb3..e5ca9a301 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -15,7 +15,7 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): llama.export_transformer_model( - hf_model_name="llSourcell/medllama2_7b", + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, compile_to="vmfb", external_weights="safetensors", @@ -32,26 +32,43 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " ) gen_external_params( - hf_model_name="llSourcell/medllama2_7b", + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", quantization=quantization, weight_path="medllama2_7b_f16_int4.safetensors", hf_auth_token=None, precision=precision, ) - from types import SimpleNamespace - - args = SimpleNamespace() - args.hf_model_name = "llSourcell/medllama2_7b" - args.hf_auth_token = None - args.vmfb_path = "medllama2_7b.vmfb" - args.external_weight_file = "medllama2_7b_f16_int4.safetensors" - args.run_vmfb = True - args.device = "llvm-cpu" - args.precision = precision - args.quantization = quantization - args.iree_target_triple = "host" - llama.run_vmfb_comparison(args) + + # def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): + DEFAULT_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? 
[/INST] +""" + + from .vmfb_comparison import run_vmfb_comparison + turbine_str, torch_str = run_vmfb_comparison( + prompt=DEFAULT_PROMPT, + hf_auth_token=None, + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + vmfb_path="medllama2_7b.vmfb", + external_weight_file="medllama2_7b_f16_int4.safetensors", + break_on_eos=True, + ) + + import difflib + # Calculate and print diff + diff = difflib.unified_diff( + torch_str.splitlines(keepends=True), + turbine_str.splitlines(keepends=True), + fromfile='torch_str', + tofile='turbine_str', + lineterm='' + ) + + assert torch_str == turbine_str, "".join(diff) + + + if __name__ == "__main__": diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py new file mode 100644 index 000000000..256a5e1b9 --- /dev/null +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -0,0 +1,163 @@ +import os +import sys +import re + +from typing import Tuple + +os.environ["TORCH_LOGS"] = "dynamic" +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from torch.utils import _pytree as pytree +from shark_turbine.aot import * +from iree.compiler.ir import Context +from iree import runtime as ireert + +from turbine_models.custom_models import remap_gguf +import safetensors + +from tqdm import tqdm + +BATCH_SIZE = 1 +MAX_STEP_SEQ = 4095 + +def torch_token_generator(prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False): + tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token) + + initial_input = tokenizer(prompt, return_tensors="pt") + input_ids = initial_input.input_ids + past_key_values = None + + while True: + model_results = model.forward(input_ids, past_key_values=past_key_values) + logits = model_results.logits + next_token_id = torch.argmax(logits[:, -1, :], dim=1) + past_key_values = model_results.past_key_values + + yield next_token_id + input_ids = next_token_id.unsqueeze(0) # Prepare for the next iteration + + if next_token_id.item() == tokenizer.eos_token_id and break_on_eos: + break + +def turbine_token_generator( + prompt: str, + hf_model_name: str, + vmfb_path: str = None, + external_weight_file: str = None, + hf_auth_token: str = None, + break_on_eos: bool = False +) -> torch.Tensor: + """ + A generator function for turbine model inference. + + :param prompt: The input prompt for the model. + :param hf_model_name: The name of the Hugging Face model. + :param vmfb_path: Path to the .vmfb model file. + :param external_weight_file: Path to the external weight file (optional). + :param hf_auth_token: Hugging Face authorization token (optional). + :param break_on_eos: Whether to break the loop on end-of-sentence token. + :return: Yields a tensor representing the generated token. 
+ """ + + # Create the config for the IREE runtime environment + config = ireert.Config("local-task") + + # Load the external weight file if provided + if external_weight_file: + index = ireert.ParameterIndex() + index.load(external_weight_file) + + # Ensure model name is in a safe format + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + + # Load the .vmfb model file + if vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") + + # Prepare the modules for the IREE runtime context + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device) + ] + + # Include parameter module if external weight file is used + if external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + # Create the system context with the given configuration and modules + ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) + + # Initialize the tokenizer + tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + + # Convert the prompt to input tensor + initial_input = tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + device_inputs = [ireert.asdevicearray(config.device, example_input_id)] + + # Get the compiled module + ModuleCompiled = ctx.modules.state_update + results = ModuleCompiled["run_initialize"](*device_inputs) + + def format_out(results): + # Convert the output to a PyTorch tensor + return torch.tensor(results.to_host()[0][0]) + + # Token generation loop + while True: + next_token_tensor = format_out(results) + yield next_token_tensor.item() # Yield the scalar value of the tensor + + # Run the next step of the model + results = ModuleCompiled["run_forward"](results) + + # Check for the end-of-sentence token + if next_token_tensor.item() == tokenizer.eos_token_id and break_on_eos: + break + + +def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): + # Initialize generators with the prompt and specific arguments + print("Using prompt:") + print(prompt) + torch_gen = torch_token_generator( + prompt=prompt, + hf_auth_token=hf_auth_token, + hf_model_name=hf_model_name, + break_on_eos=break_on_eos + ) + + print("Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while.") + torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) + del torch_gen + + # Run turbine until an equal number of tokens has been generated + print("Generating Turbine tokens... 
The pipeline needs to be initialized first so the first few tokens may take a while.") + turbine_gen = turbine_token_generator( + prompt=prompt, + hf_model_name=hf_model_name, + vmfb_path=vmfb_path, + external_weight_file=external_weight_file, + hf_auth_token=hf_auth_token, + break_on_eos=break_on_eos + ) + turbine_tokens = [] + for _ in tqdm(range(len(torch_tokens)), desc="Generating Turbine tokens"): + token = next(turbine_gen) + turbine_tokens.append(token) + del turbine_gen + + # Decode and print the outputs + tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + turbine_str = (tokenizer.decode(torch.tensor(turbine_tokens).numpy())) + torch_str = (tokenizer.decode(torch.tensor(torch_tokens).numpy())) + return turbine_str, torch_str \ No newline at end of file From 93952485a11b39235e9d781b80b1fba4591b2ef8 Mon Sep 17 00:00:00 2001 From: IanNod <45800100+IanNod@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:35:25 -0800 Subject: [PATCH 07/28] Adds tests for gen_external_params.py quantize function (#225) --- .../tests/gen_external_params_test.py | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 python/turbine_models/tests/gen_external_params_test.py diff --git a/python/turbine_models/tests/gen_external_params_test.py b/python/turbine_models/tests/gen_external_params_test.py new file mode 100644 index 000000000..503a2af6a --- /dev/null +++ b/python/turbine_models/tests/gen_external_params_test.py @@ -0,0 +1,132 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from turbine_models.gen_external_params.gen_external_params import quantize +from turbine_models.model_builder import HFTransformerBuilder +from transformers import AutoTokenizer, AutoModelForCausalLM +import unittest +import os +import torch +import pytest + + +class ExternalParamsTest(unittest.TestCase): + def testQuantizeF32(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model, "", torch.float32) + for weight in quant_weights: + self.assertNotIn("weight_zp", weight) + self.assertNotIn("weight_scale", weight) + assert quant_weights[weight].dtype in [torch.float32] + + def testQuantizeF32I8(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model, "int8", torch.float32) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float32] + else: + assert quant_weights[weight].dtype in [torch.float32] + + def testQuantizeF32I4(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model, "int4", torch.float32) + named_params = 
dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == 2 * quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float32] + else: + assert quant_weights[weight].dtype in [torch.float32] + + def testQuantizeF16(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model.half(), "", torch.float16) + for weight in quant_weights: + self.assertNotIn("weight_zp", weight) + self.assertNotIn("weight_scale", weight) + assert quant_weights[weight].dtype in [torch.float16] + + @pytest.mark.xfail(reason="brevitas issue with f16 int8 quanttensor") + def testQuantizeF16I8(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model.half(), "int8", torch.float16) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float16] + else: + assert quant_weights[weight].dtype in [torch.float16] + + def testQuantizeF16I4(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model.half(), "int4", torch.float16) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == 2 * quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float16] + else: + assert quant_weights[weight].dtype in [torch.float16] + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From f6c8008136cf0484d13682ab446e2e7288a32e8c Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:37:58 +0000 Subject: [PATCH 08/28] remove python-fire dependency --- .../gen_external_params.py | 49 +++++++++++-------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index de684e9f5..ba5614ec0 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -3,24 +3,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM import torch -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name ID", - default="meta-llama/Llama-2-7b-chat-hf", -) -parser.add_argument("--quantization", type=str, 
default="int4") -parser.add_argument("--weight_path", type=str, default="") -parser.add_argument( - "--hf_auth_token", type=str, help="The HF auth token required for some models" -) -parser.add_argument( - "--precision", type=str, default="f16", help="Data type of model [f16, f32]" -) - def quantize(model, quantization, dtype): accumulates = dtype @@ -118,6 +100,33 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) + if __name__ == "__main__": - import fire - fire.Fire(gen_external_params) \ No newline at end of file + import argparse + import sys + parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") + + parser.add_argument("--hf_model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf", + help="The Hugging Face model name ID.") + parser.add_argument("--quantization", type=str, default="int4", + choices=["int4", "int8"], help="Type of quantization to apply.") + parser.add_argument("--weight_path", type=str, default="", + help="Path to save the quantized model weights.") + parser.add_argument("--hf_auth_token", type=str, default=None, + help="The Hugging Face auth token required for some models.") + parser.add_argument("--precision", type=str, default="f16", + choices=["f16", "f32"], help="Data type of model.") + + args = parser.parse_args() + + try: + gen_external_params( + hf_model_name=args.hf_model_name, + quantization=args.quantization, + weight_path=args.weight_path, + hf_auth_token=args.hf_auth_token, + precision=args.precision + ) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) \ No newline at end of file From 612556a7e5c12c6a23e00d1dfe3f2e2e8448e209 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:42:06 +0000 Subject: [PATCH 09/28] remove unnecessary vulcan max_alloc and rename for consistency between argparse and function params --- python/turbine_models/custom_models/stateless_llama.py | 4 ++-- python/turbine_models/tests/stateless_llama_test.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 48239fb8f..69796e549 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -92,7 +92,7 @@ def export_transformer_model( precision=None, device=None, target_triple=None, - max_alloc=None, + vulkan_max_allocation=None, ): state_schema = pytree.treespec_loads(json_schema) @@ -243,7 +243,7 @@ def forward(token0: torch.Tensor, *state0_flat): flags.extend( [ "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" + max_alloc, + "--iree-stream-resource-max-allocation-size=" + vulkan_max_allocation, ] ) elif device == "rocm": diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index e5ca9a301..b86135e84 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -24,7 +24,6 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " precision=precision, device="llvm-cpu", target_triple="host", - max_alloc="4294967296", ) from turbine_models.gen_external_params.gen_external_params import ( From 200dbe52dcbaad8eb18514b8562b8bf19419640d Mon Sep 17 00:00:00 2001 From: Xida Ren Date: 
Thu, 7 Dec 2023 18:42:38 +0000 Subject: [PATCH 10/28] black --- .../custom_models/stateless_llama.py | 4 +- .../gen_external_params.py | 71 +++++++++++++------ .../tests/stateless_llama_test.py | 16 ++--- .../turbine_models/tests/vmfb_comparison.py | 66 +++++++++++------ 4 files changed, 102 insertions(+), 55 deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 69796e549..337f94122 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -243,7 +243,8 @@ def forward(token0: torch.Tensor, *state0_flat): flags.extend( [ "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" + vulkan_max_allocation, + "--iree-stream-resource-max-allocation-size=" + + vulkan_max_allocation, ] ) elif device == "rocm": @@ -279,6 +280,7 @@ def forward(token0: torch.Tensor, *state0_flat): print("saved to ", safe_name + ".vmfb") return module_str, tokenizer + # if you're looking for run_vmfb_comparison, it's now in python/turbine_models/tests/vmfb_comparison.py if __name__ == "__main__": diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index ba5614ec0..5bf711994 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -54,11 +54,14 @@ def forward(self, x): all_weights.update(int_weights) return all_weights -def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", - quantization:str = "int4", - weight_path:str = "", - hf_auth_token:str = None, - precision:str = "f16"): + +def gen_external_params( + hf_model_name: str = "meta-llama/Llama-2-7b-chat-hf", + quantization: str = "int4", + weight_path: str = "", + hf_auth_token: str = None, + precision: str = "f16", +): """ Main function to run the model quantization and saving process. @@ -90,13 +93,12 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", if weight_path == "": save_path = hf_model_name.split("/")[-1].strip() save_path = re.sub("-", "_", save_path) - save_path = ( - save_path + "_" + precision + "_" + quantization + ".safetensors" - ) + save_path = save_path + "_" + precision + "_" + quantization + ".safetensors" else: save_path = weight_path import safetensors + safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) @@ -104,18 +106,43 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", if __name__ == "__main__": import argparse import sys - parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") - - parser.add_argument("--hf_model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf", - help="The Hugging Face model name ID.") - parser.add_argument("--quantization", type=str, default="int4", - choices=["int4", "int8"], help="Type of quantization to apply.") - parser.add_argument("--weight_path", type=str, default="", - help="Path to save the quantized model weights.") - parser.add_argument("--hf_auth_token", type=str, default=None, - help="The Hugging Face auth token required for some models.") - parser.add_argument("--precision", type=str, default="f16", - choices=["f16", "f32"], help="Data type of model.") + + parser = argparse.ArgumentParser( + description="Quantize and save Hugging Face models." 
+ ) + + parser.add_argument( + "--hf_model_name", + type=str, + default="meta-llama/Llama-2-7b-chat-hf", + help="The Hugging Face model name ID.", + ) + parser.add_argument( + "--quantization", + type=str, + default="int4", + choices=["int4", "int8"], + help="Type of quantization to apply.", + ) + parser.add_argument( + "--weight_path", + type=str, + default="", + help="Path to save the quantized model weights.", + ) + parser.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="The Hugging Face auth token required for some models.", + ) + parser.add_argument( + "--precision", + type=str, + default="f16", + choices=["f16", "f32"], + help="Data type of model.", + ) args = parser.parse_args() @@ -125,8 +152,8 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", quantization=args.quantization, weight_path=args.weight_path, hf_auth_token=args.hf_auth_token, - precision=args.precision + precision=args.precision, ) except Exception as e: print(f"Error: {e}", file=sys.stderr) - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index b86135e84..2ae38fc25 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -38,13 +38,13 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " precision=precision, ) - # def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): DEFAULT_PROMPT = """[INST] <> Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? 
[/INST] """ from .vmfb_comparison import run_vmfb_comparison + turbine_str, torch_str = run_vmfb_comparison( prompt=DEFAULT_PROMPT, hf_auth_token=None, @@ -55,19 +55,17 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " ) import difflib - # Calculate and print diff + + # Calculate and print diff diff = difflib.unified_diff( torch_str.splitlines(keepends=True), turbine_str.splitlines(keepends=True), - fromfile='torch_str', - tofile='turbine_str', - lineterm='' + fromfile="torch_str", + tofile="turbine_str", + lineterm="", ) - - assert torch_str == turbine_str, "".join(diff) - - + assert torch_str == turbine_str, "".join(diff) if __name__ == "__main__": diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py index 256a5e1b9..4d4dcc15b 100644 --- a/python/turbine_models/tests/vmfb_comparison.py +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -20,9 +20,16 @@ BATCH_SIZE = 1 MAX_STEP_SEQ = 4095 -def torch_token_generator(prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False): - tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) - model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token) + +def torch_token_generator( + prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False +): + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + model = AutoModelForCausalLM.from_pretrained( + hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token + ) initial_input = tokenizer(prompt, return_tensors="pt") input_ids = initial_input.input_ids @@ -40,13 +47,14 @@ def torch_token_generator(prompt, hf_model_name: str, hf_auth_token: str, break_ if next_token_id.item() == tokenizer.eos_token_id and break_on_eos: break + def turbine_token_generator( - prompt: str, - hf_model_name: str, - vmfb_path: str = None, - external_weight_file: str = None, - hf_auth_token: str = None, - break_on_eos: bool = False + prompt: str, + hf_model_name: str, + vmfb_path: str = None, + external_weight_file: str = None, + hf_auth_token: str = None, + break_on_eos: bool = False, ) -> torch.Tensor: """ A generator function for turbine model inference. 
@@ -81,10 +89,7 @@ def turbine_token_generator( raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") # Prepare the modules for the IREE runtime context - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device) - ] + vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)] # Include parameter module if external weight file is used if external_weight_file: @@ -97,7 +102,9 @@ def turbine_token_generator( ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) # Initialize the tokenizer - tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) # Convert the prompt to input tensor initial_input = tokenizer(prompt, return_tensors="pt") @@ -125,7 +132,14 @@ def format_out(results): break -def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): +def run_vmfb_comparison( + prompt, + hf_auth_token, + hf_model_name, + vmfb_path, + external_weight_file, + break_on_eos=True, +): # Initialize generators with the prompt and specific arguments print("Using prompt:") print(prompt) @@ -133,22 +147,26 @@ def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, externa prompt=prompt, hf_auth_token=hf_auth_token, hf_model_name=hf_model_name, - break_on_eos=break_on_eos + break_on_eos=break_on_eos, ) - print("Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while.") + print( + "Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while." + ) torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) del torch_gen # Run turbine until an equal number of tokens has been generated - print("Generating Turbine tokens... The pipeline needs to be initialized first so the first few tokens may take a while.") + print( + "Generating Turbine tokens... The pipeline needs to be initialized first so the first few tokens may take a while." 
+ ) turbine_gen = turbine_token_generator( prompt=prompt, hf_model_name=hf_model_name, vmfb_path=vmfb_path, external_weight_file=external_weight_file, hf_auth_token=hf_auth_token, - break_on_eos=break_on_eos + break_on_eos=break_on_eos, ) turbine_tokens = [] for _ in tqdm(range(len(torch_tokens)), desc="Generating Turbine tokens"): @@ -157,7 +175,9 @@ def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, externa del turbine_gen # Decode and print the outputs - tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) - turbine_str = (tokenizer.decode(torch.tensor(turbine_tokens).numpy())) - torch_str = (tokenizer.decode(torch.tensor(torch_tokens).numpy())) - return turbine_str, torch_str \ No newline at end of file + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + turbine_str = tokenizer.decode(torch.tensor(turbine_tokens).numpy()) + torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) + return turbine_str, torch_str From 4f58d78c2698d8ddbfa7f3b18607938d6f66d377 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 19:10:32 +0000 Subject: [PATCH 11/28] resolve merge conflicts --- .github/workflows/test_models.yml | 2 +- .../custom_models/stateless_llama.py | 10 ++-- .../gen_external_params.py | 45 ++++++++++----- python/turbine_models/tests/conftest.py | 18 ++++++ .../tests/stateless_llama_test.py | 57 +++++++++++++++++++ 5 files changed, 112 insertions(+), 20 deletions(-) create mode 100644 python/turbine_models/tests/conftest.py create mode 100644 python/turbine_models/tests/stateless_llama_test.py diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 3227c2769..36da25eea 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -1,4 +1,4 @@ -name: Test +name: turbine_models test on: workflow_dispatch: diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 61727d673..59e09e1e4 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -48,12 +48,12 @@ "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +parser.add_argument("--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm") # TODO: Bring in detection for target triple parser.add_argument( "--iree_target_triple", type=str, - default="", + default="host", help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") @@ -236,7 +236,7 @@ def forward(token0: torch.Tensor, *state0_flat): "--iree-codegen-check-ir-before-llvm-conversion=false", "--iree-opt-const-expr-hoisting=False", ] - if device == "cpu": + if device == "cpu" or device == "llvm-cpu": flags.append("--iree-llvmcpu-enable-ukernels=all") device = "llvm-cpu" elif device == "vulkan": @@ -277,11 +277,11 @@ def forward(token0: torch.Tensor, *state0_flat): with open(f"{safe_name}.vmfb", "wb+") as f: f.write(flatbuffer_blob) print("saved to ", safe_name + ".vmfb") - exit() + return module_str, tokenizer def run_vmfb_comparison(args): - config = ireert.Config(args.device) + config = ireert.Config("local-task") if args.external_weight_file: index = ireert.ParameterIndex() diff --git 
a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 7e993de72..de684e9f5 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -72,35 +72,52 @@ def forward(self, x): all_weights.update(int_weights) return all_weights +def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", + quantization:str = "int4", + weight_path:str = "", + hf_auth_token:str = None, + precision:str = "f16"): + """ + Main function to run the model quantization and saving process. -if __name__ == "__main__": - args = parser.parse_args() + :param hf_model_name: The Hugging Face model name ID. + :param quantization: Type of quantization to apply ('int4' or 'int8'). + :param weight_path: Path to save the quantized model weights. + :param hf_auth_token: The Hugging Face auth token required for some models. + :param precision: Data type of model ('f16' or 'f32'). + """ model_builder = HFTransformerBuilder( example_input=None, - hf_id=args.hf_model_name, + hf_id=hf_model_name, auto_model=AutoModelForCausalLM, - hf_auth_token=args.hf_auth_token, + hf_auth_token=hf_auth_token, ) model_builder.build_model() - if args.precision == "f16": + + if precision == "f16": model = model_builder.model.half() dtype = torch.float16 - elif args.precision == "f32": + elif precision == "f32": model = model_builder.model dtype = torch.float32 else: - sys.exit("invalid precision, f16 or f32 supported") - quant_weights = quantize(model, args.quantization, dtype) - # TODO: Add more than just safetensor support - import safetensors + sys.exit("Invalid precision, f16 or f32 supported") + + quant_weights = quantize(model, quantization, dtype) - if args.weight_path == "": - save_path = args.hf_model_name.split("/")[-1].strip() + if weight_path == "": + save_path = hf_model_name.split("/")[-1].strip() save_path = re.sub("-", "_", save_path) save_path = ( - save_path + "_" + args.precision + "_" + args.quantization + ".safetensors" + save_path + "_" + precision + "_" + quantization + ".safetensors" ) else: - save_path = args.weight_path + save_path = weight_path + + import safetensors safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) + +if __name__ == "__main__": + import fire + fire.Fire(gen_external_params) \ No newline at end of file diff --git a/python/turbine_models/tests/conftest.py b/python/turbine_models/tests/conftest.py new file mode 100644 index 000000000..ae6071971 --- /dev/null +++ b/python/turbine_models/tests/conftest.py @@ -0,0 +1,18 @@ +def pytest_addoption(parser): + parser.addoption("--all", action="store_true", help="run all combinations") + + +def pytest_generate_tests(metafunc): + if "quantization" in metafunc.fixturenames: + if metafunc.config.getoption("all"): + quantizations = ["int4", None] + else: + quantizations = ["int4"] + metafunc.parametrize("quantization", quantizations) + + if "precision" in metafunc.fixturenames: + if metafunc.config.getoption("all"): + precisions = ["f16", "f32"] + else: + precisions = ["f16"] + metafunc.parametrize("precision", precisions) \ No newline at end of file diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py new file mode 100644 index 000000000..0aaeae6ff --- /dev/null +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 Nod Labs, Inc 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import turbine_models.custom_models.stateless_llama as llama +import unittest +import os +import pytest + +from typing import Literal + +def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): + + llama.export_transformer_model( + hf_model_name="llSourcell/medllama2_7b", + hf_auth_token=None, + compile_to="vmfb", + external_weights="safetensors", + # external_weight_file="medllama2_7b_f16_int4.safetensors", Do not export weights because this doesn't get quantized + quantization=quantization, + precision=precision, + device="llvm-cpu", + target_triple="host", + max_alloc = "4294967296" + ) + + from turbine_models.gen_external_params.gen_external_params import gen_external_params + gen_external_params( + hf_model_name="llSourcell/medllama2_7b", + quantization=quantization, + weight_path="medllama2_7b_f16_int4.safetensors", + hf_auth_token=None, + precision=precision + ) + + from types import SimpleNamespace + args = SimpleNamespace() + args.hf_model_name = "llSourcell/medllama2_7b" + args.hf_auth_token = None + args.vmfb_path = "medllama2_7b.vmfb" + args.external_weight_file = "medllama2_7b_f16_int4.safetensors" + args.run_vmfb = True + args.device="llvm-cpu" + args.precision = precision + args.quantization = quantization + args.iree_target_triple="host" + llama.run_vmfb_comparison(args) + + + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From 3ab45d186277c3451b50f23c28ae389f07be859a Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 01:12:34 +0000 Subject: [PATCH 12/28] adjust naming of tests to look clearer --- .github/workflows/test_models.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 36da25eea..e353273e2 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -1,4 +1,4 @@ -name: turbine_models test +name: Test Turbine Models on: workflow_dispatch: @@ -8,7 +8,7 @@ on: - main jobs: - test: + test-turbine-models: strategy: matrix: version: [3.11] From e552a1afe199b92815b87f123b4940a0f9f83dbf Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 01:13:28 +0000 Subject: [PATCH 13/28] fix formatting with black --- .../custom_models/stateless_llama.py | 4 +++- python/turbine_models/tests/conftest.py | 4 ++-- .../tests/stateless_llama_test.py | 18 ++++++++++-------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 59e09e1e4..7fc20b334 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -48,7 +48,9 @@ "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) -parser.add_argument("--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm") +parser.add_argument( + "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm" +) # TODO: Bring in detection for target triple parser.add_argument( "--iree_target_triple", diff --git a/python/turbine_models/tests/conftest.py b/python/turbine_models/tests/conftest.py index ae6071971..1b1fe1f67 100644 --- a/python/turbine_models/tests/conftest.py +++ 
b/python/turbine_models/tests/conftest.py @@ -9,10 +9,10 @@ def pytest_generate_tests(metafunc): else: quantizations = ["int4"] metafunc.parametrize("quantization", quantizations) - + if "precision" in metafunc.fixturenames: if metafunc.config.getoption("all"): precisions = ["f16", "f32"] else: precisions = ["f16"] - metafunc.parametrize("precision", precisions) \ No newline at end of file + metafunc.parametrize("precision", precisions) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 0aaeae6ff..4b957afb3 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -12,8 +12,8 @@ from typing import Literal + def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): - llama.export_transformer_model( hf_model_name="llSourcell/medllama2_7b", hf_auth_token=None, @@ -24,32 +24,34 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " precision=precision, device="llvm-cpu", target_triple="host", - max_alloc = "4294967296" + max_alloc="4294967296", + ) + + from turbine_models.gen_external_params.gen_external_params import ( + gen_external_params, ) - from turbine_models.gen_external_params.gen_external_params import gen_external_params gen_external_params( hf_model_name="llSourcell/medllama2_7b", quantization=quantization, weight_path="medllama2_7b_f16_int4.safetensors", hf_auth_token=None, - precision=precision + precision=precision, ) from types import SimpleNamespace + args = SimpleNamespace() args.hf_model_name = "llSourcell/medllama2_7b" args.hf_auth_token = None args.vmfb_path = "medllama2_7b.vmfb" args.external_weight_file = "medllama2_7b_f16_int4.safetensors" args.run_vmfb = True - args.device="llvm-cpu" + args.device = "llvm-cpu" args.precision = precision args.quantization = quantization - args.iree_target_triple="host" + args.iree_target_triple = "host" llama.run_vmfb_comparison(args) - - if __name__ == "__main__": From 23fd558c918255c7238a53545d3c49c2ae7cf943 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:21:13 +0000 Subject: [PATCH 14/28] fold run_vmfb_comparison into python/turbine_models/tests and actually fail tests on comparison fail --- .../custom_models/stateless_llama.py | 119 ++----------- .../tests/stateless_llama_test.py | 47 +++-- .../turbine_models/tests/vmfb_comparison.py | 163 ++++++++++++++++++ 3 files changed, 213 insertions(+), 116 deletions(-) create mode 100644 python/turbine_models/tests/vmfb_comparison.py diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 7fc20b334..4538b0559 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -19,7 +19,6 @@ import argparse parser = argparse.ArgumentParser() -parser.add_argument("--run_vmfb", action="store_true") parser.add_argument( "--hf_auth_token", type=str, help="The Hugging Face auth token, required" ) @@ -281,107 +280,25 @@ def forward(token0: torch.Tensor, *state0_flat): print("saved to ", safe_name + ".vmfb") return module_str, tokenizer +# if you're looking for run_vmfb_comparison, it's now in python/turbine_models/tests/vmfb_comparison.py -def run_vmfb_comparison(args): - config = ireert.Config("local-task") - - if args.external_weight_file: - index = ireert.ParameterIndex() - index.load(args.external_weight_file) - - safe_name = 
args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - if args.vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - sys.exit("no vmfb_path provided, required for run_vmfb") - - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device), - ] - if args.external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) - - ctx = ireert.SystemContext( - vm_modules=vm_modules, - config=config, - ) - tokenizer = AutoTokenizer.from_pretrained( - args.hf_model_name, - use_fast=False, - token=args.hf_auth_token, - ) - initial_input = tokenizer(prompt, return_tensors="pt") - example_input_id = initial_input.input_ids - device_inputs = [ireert.asdevicearray(config.device, example_input_id)] - - ModuleCompiled = ctx.modules.state_update - results = ModuleCompiled["run_initialize"](*device_inputs) - - def format_out(results): - return torch.tensor(results.to_host()[0][0]) +if __name__ == "__main__": + args = parser.parse_args() - model = AutoModelForCausalLM.from_pretrained( + mod_str, _ = export_transformer_model( args.hf_model_name, - torch_dtype=torch.float, - use_auth_token=args.hf_auth_token, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.quantization, + args.precision, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, ) - - def get_token_from_logits(logits): - return torch.argmax(logits[:, -1, :], dim=1) - - base_model_results = model.forward(example_input_id) - base_model_token = get_token_from_logits(base_model_results.logits) - bm_pkv = base_model_results.past_key_values - turbine_results = [] - torch_results = [] - turbine_results.append(format_out(results)) - torch_results.append(int(base_model_token)) - while base_model_token != 2: - results = ModuleCompiled["run_forward"](results) - base_model_results = model.forward( - torch.unsqueeze(base_model_token, 0), past_key_values=bm_pkv - ) - base_model_token = get_token_from_logits(base_model_results.logits) - - bm_pkv = base_model_results.past_key_values - # uncomment to see tokens as they are emittd - # print(f"pytorch: {tokenizer.decode(base_model_token)}") - # print(f"turbine: {tokenizer.decode(format_out(results))}") - turbine_results.append(format_out(results)) - torch_results.append(int(base_model_token[0])) - - print("turbine output: ") - print(tokenizer.decode(turbine_results)) - print("\ntorch output: ") - print(tokenizer.decode(torch_results)) - - -if __name__ == "__main__": - args = parser.parse_args() - if args.run_vmfb: - run_vmfb_comparison(args) - else: - mod_str, _ = export_transformer_model( - args.hf_model_name, - args.hf_auth_token, - args.compile_to, - args.external_weights, - args.external_weight_file, - args.quantization, - args.precision, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to ", safe_name + ".mlir") + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to ", safe_name + ".mlir") diff --git 
a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 4b957afb3..e5ca9a301 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -15,7 +15,7 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): llama.export_transformer_model( - hf_model_name="llSourcell/medllama2_7b", + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, compile_to="vmfb", external_weights="safetensors", @@ -32,26 +32,43 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " ) gen_external_params( - hf_model_name="llSourcell/medllama2_7b", + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", quantization=quantization, weight_path="medllama2_7b_f16_int4.safetensors", hf_auth_token=None, precision=precision, ) - from types import SimpleNamespace - - args = SimpleNamespace() - args.hf_model_name = "llSourcell/medllama2_7b" - args.hf_auth_token = None - args.vmfb_path = "medllama2_7b.vmfb" - args.external_weight_file = "medllama2_7b_f16_int4.safetensors" - args.run_vmfb = True - args.device = "llvm-cpu" - args.precision = precision - args.quantization = quantization - args.iree_target_triple = "host" - llama.run_vmfb_comparison(args) + + # def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): + DEFAULT_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? 
[/INST] +""" + + from .vmfb_comparison import run_vmfb_comparison + turbine_str, torch_str = run_vmfb_comparison( + prompt=DEFAULT_PROMPT, + hf_auth_token=None, + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + vmfb_path="medllama2_7b.vmfb", + external_weight_file="medllama2_7b_f16_int4.safetensors", + break_on_eos=True, + ) + + import difflib + # Calculate and print diff + diff = difflib.unified_diff( + torch_str.splitlines(keepends=True), + turbine_str.splitlines(keepends=True), + fromfile='torch_str', + tofile='turbine_str', + lineterm='' + ) + + assert torch_str == turbine_str, "".join(diff) + + + if __name__ == "__main__": diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py new file mode 100644 index 000000000..256a5e1b9 --- /dev/null +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -0,0 +1,163 @@ +import os +import sys +import re + +from typing import Tuple + +os.environ["TORCH_LOGS"] = "dynamic" +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from torch.utils import _pytree as pytree +from shark_turbine.aot import * +from iree.compiler.ir import Context +from iree import runtime as ireert + +from turbine_models.custom_models import remap_gguf +import safetensors + +from tqdm import tqdm + +BATCH_SIZE = 1 +MAX_STEP_SEQ = 4095 + +def torch_token_generator(prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False): + tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token) + + initial_input = tokenizer(prompt, return_tensors="pt") + input_ids = initial_input.input_ids + past_key_values = None + + while True: + model_results = model.forward(input_ids, past_key_values=past_key_values) + logits = model_results.logits + next_token_id = torch.argmax(logits[:, -1, :], dim=1) + past_key_values = model_results.past_key_values + + yield next_token_id + input_ids = next_token_id.unsqueeze(0) # Prepare for the next iteration + + if next_token_id.item() == tokenizer.eos_token_id and break_on_eos: + break + +def turbine_token_generator( + prompt: str, + hf_model_name: str, + vmfb_path: str = None, + external_weight_file: str = None, + hf_auth_token: str = None, + break_on_eos: bool = False +) -> torch.Tensor: + """ + A generator function for turbine model inference. + + :param prompt: The input prompt for the model. + :param hf_model_name: The name of the Hugging Face model. + :param vmfb_path: Path to the .vmfb model file. + :param external_weight_file: Path to the external weight file (optional). + :param hf_auth_token: Hugging Face authorization token (optional). + :param break_on_eos: Whether to break the loop on end-of-sentence token. + :return: Yields a tensor representing the generated token. 
+ """ + + # Create the config for the IREE runtime environment + config = ireert.Config("local-task") + + # Load the external weight file if provided + if external_weight_file: + index = ireert.ParameterIndex() + index.load(external_weight_file) + + # Ensure model name is in a safe format + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + + # Load the .vmfb model file + if vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") + + # Prepare the modules for the IREE runtime context + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device) + ] + + # Include parameter module if external weight file is used + if external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + # Create the system context with the given configuration and modules + ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) + + # Initialize the tokenizer + tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + + # Convert the prompt to input tensor + initial_input = tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + device_inputs = [ireert.asdevicearray(config.device, example_input_id)] + + # Get the compiled module + ModuleCompiled = ctx.modules.state_update + results = ModuleCompiled["run_initialize"](*device_inputs) + + def format_out(results): + # Convert the output to a PyTorch tensor + return torch.tensor(results.to_host()[0][0]) + + # Token generation loop + while True: + next_token_tensor = format_out(results) + yield next_token_tensor.item() # Yield the scalar value of the tensor + + # Run the next step of the model + results = ModuleCompiled["run_forward"](results) + + # Check for the end-of-sentence token + if next_token_tensor.item() == tokenizer.eos_token_id and break_on_eos: + break + + +def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): + # Initialize generators with the prompt and specific arguments + print("Using prompt:") + print(prompt) + torch_gen = torch_token_generator( + prompt=prompt, + hf_auth_token=hf_auth_token, + hf_model_name=hf_model_name, + break_on_eos=break_on_eos + ) + + print("Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while.") + torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) + del torch_gen + + # Run turbine until an equal number of tokens has been generated + print("Generating Turbine tokens... 
The pipeline needs to be initialized first so the first few tokens may take a while.") + turbine_gen = turbine_token_generator( + prompt=prompt, + hf_model_name=hf_model_name, + vmfb_path=vmfb_path, + external_weight_file=external_weight_file, + hf_auth_token=hf_auth_token, + break_on_eos=break_on_eos + ) + turbine_tokens = [] + for _ in tqdm(range(len(torch_tokens)), desc="Generating Turbine tokens"): + token = next(turbine_gen) + turbine_tokens.append(token) + del turbine_gen + + # Decode and print the outputs + tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + turbine_str = (tokenizer.decode(torch.tensor(turbine_tokens).numpy())) + torch_str = (tokenizer.decode(torch.tensor(torch_tokens).numpy())) + return turbine_str, torch_str \ No newline at end of file From 8b32886220b9bd6199611ae7c45f442c73cabeb0 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:37:58 +0000 Subject: [PATCH 15/28] remove python-fire dependency --- .../gen_external_params.py | 49 +++++++++++-------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index de684e9f5..ba5614ec0 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -3,24 +3,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM import torch -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_model_name", - type=str, - help="HF model name ID", - default="meta-llama/Llama-2-7b-chat-hf", -) -parser.add_argument("--quantization", type=str, default="int4") -parser.add_argument("--weight_path", type=str, default="") -parser.add_argument( - "--hf_auth_token", type=str, help="The HF auth token required for some models" -) -parser.add_argument( - "--precision", type=str, default="f16", help="Data type of model [f16, f32]" -) - def quantize(model, quantization, dtype): accumulates = dtype @@ -118,6 +100,33 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) + if __name__ == "__main__": - import fire - fire.Fire(gen_external_params) \ No newline at end of file + import argparse + import sys + parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") + + parser.add_argument("--hf_model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf", + help="The Hugging Face model name ID.") + parser.add_argument("--quantization", type=str, default="int4", + choices=["int4", "int8"], help="Type of quantization to apply.") + parser.add_argument("--weight_path", type=str, default="", + help="Path to save the quantized model weights.") + parser.add_argument("--hf_auth_token", type=str, default=None, + help="The Hugging Face auth token required for some models.") + parser.add_argument("--precision", type=str, default="f16", + choices=["f16", "f32"], help="Data type of model.") + + args = parser.parse_args() + + try: + gen_external_params( + hf_model_name=args.hf_model_name, + quantization=args.quantization, + weight_path=args.weight_path, + hf_auth_token=args.hf_auth_token, + precision=args.precision + ) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) \ No newline at end of file From 75a4fc79480980a7d160ba5cfae276d2b7ff3d24 Mon Sep 17 
00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:42:06 +0000 Subject: [PATCH 16/28] remove unnecessary vulcan max_alloc and rename for consistency between argparse and function params --- python/turbine_models/custom_models/stateless_llama.py | 4 ++-- python/turbine_models/tests/stateless_llama_test.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 4538b0559..a41e4c98c 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -92,7 +92,7 @@ def export_transformer_model( precision=None, device=None, target_triple=None, - max_alloc=None, + vulkan_max_allocation=None, ): state_schema = pytree.treespec_loads(json_schema) @@ -244,7 +244,7 @@ def forward(token0: torch.Tensor, *state0_flat): flags.extend( [ "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" + max_alloc, + "--iree-stream-resource-max-allocation-size=" + vulkan_max_allocation, ] ) elif device == "rocm": diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index e5ca9a301..b86135e84 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -24,7 +24,6 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " precision=precision, device="llvm-cpu", target_triple="host", - max_alloc="4294967296", ) from turbine_models.gen_external_params.gen_external_params import ( From 48e2685524c3190fd86f7285de7e33eb983143f5 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 18:42:38 +0000 Subject: [PATCH 17/28] black --- .../custom_models/stateless_llama.py | 4 +- .../gen_external_params.py | 71 +++++++++++++------ .../tests/stateless_llama_test.py | 16 ++--- .../turbine_models/tests/vmfb_comparison.py | 66 +++++++++++------ 4 files changed, 102 insertions(+), 55 deletions(-) diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index a41e4c98c..7e1ad4341 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -244,7 +244,8 @@ def forward(token0: torch.Tensor, *state0_flat): flags.extend( [ "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" + vulkan_max_allocation, + "--iree-stream-resource-max-allocation-size=" + + vulkan_max_allocation, ] ) elif device == "rocm": @@ -280,6 +281,7 @@ def forward(token0: torch.Tensor, *state0_flat): print("saved to ", safe_name + ".vmfb") return module_str, tokenizer + # if you're looking for run_vmfb_comparison, it's now in python/turbine_models/tests/vmfb_comparison.py if __name__ == "__main__": diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index ba5614ec0..5bf711994 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -54,11 +54,14 @@ def forward(self, x): all_weights.update(int_weights) return all_weights -def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", - quantization:str = "int4", - weight_path:str = "", - hf_auth_token:str = None, - precision:str = "f16"): + +def 
gen_external_params( + hf_model_name: str = "meta-llama/Llama-2-7b-chat-hf", + quantization: str = "int4", + weight_path: str = "", + hf_auth_token: str = None, + precision: str = "f16", +): """ Main function to run the model quantization and saving process. @@ -90,13 +93,12 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", if weight_path == "": save_path = hf_model_name.split("/")[-1].strip() save_path = re.sub("-", "_", save_path) - save_path = ( - save_path + "_" + precision + "_" + quantization + ".safetensors" - ) + save_path = save_path + "_" + precision + "_" + quantization + ".safetensors" else: save_path = weight_path import safetensors + safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) @@ -104,18 +106,43 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", if __name__ == "__main__": import argparse import sys - parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") - - parser.add_argument("--hf_model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf", - help="The Hugging Face model name ID.") - parser.add_argument("--quantization", type=str, default="int4", - choices=["int4", "int8"], help="Type of quantization to apply.") - parser.add_argument("--weight_path", type=str, default="", - help="Path to save the quantized model weights.") - parser.add_argument("--hf_auth_token", type=str, default=None, - help="The Hugging Face auth token required for some models.") - parser.add_argument("--precision", type=str, default="f16", - choices=["f16", "f32"], help="Data type of model.") + + parser = argparse.ArgumentParser( + description="Quantize and save Hugging Face models." + ) + + parser.add_argument( + "--hf_model_name", + type=str, + default="meta-llama/Llama-2-7b-chat-hf", + help="The Hugging Face model name ID.", + ) + parser.add_argument( + "--quantization", + type=str, + default="int4", + choices=["int4", "int8"], + help="Type of quantization to apply.", + ) + parser.add_argument( + "--weight_path", + type=str, + default="", + help="Path to save the quantized model weights.", + ) + parser.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="The Hugging Face auth token required for some models.", + ) + parser.add_argument( + "--precision", + type=str, + default="f16", + choices=["f16", "f32"], + help="Data type of model.", + ) args = parser.parse_args() @@ -125,8 +152,8 @@ def gen_external_params(hf_model_name:str = "meta-llama/Llama-2-7b-chat-hf", quantization=args.quantization, weight_path=args.weight_path, hf_auth_token=args.hf_auth_token, - precision=args.precision + precision=args.precision, ) except Exception as e: print(f"Error: {e}", file=sys.stderr) - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index b86135e84..2ae38fc25 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -38,13 +38,13 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " precision=precision, ) - # def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): DEFAULT_PROMPT = """[INST] <> Be concise. You are a helpful, respectful and honest assistant. 
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST] """ from .vmfb_comparison import run_vmfb_comparison + turbine_str, torch_str = run_vmfb_comparison( prompt=DEFAULT_PROMPT, hf_auth_token=None, @@ -55,19 +55,17 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " ) import difflib - # Calculate and print diff + + # Calculate and print diff diff = difflib.unified_diff( torch_str.splitlines(keepends=True), turbine_str.splitlines(keepends=True), - fromfile='torch_str', - tofile='turbine_str', - lineterm='' + fromfile="torch_str", + tofile="turbine_str", + lineterm="", ) - - assert torch_str == turbine_str, "".join(diff) - - + assert torch_str == turbine_str, "".join(diff) if __name__ == "__main__": diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py index 256a5e1b9..4d4dcc15b 100644 --- a/python/turbine_models/tests/vmfb_comparison.py +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -20,9 +20,16 @@ BATCH_SIZE = 1 MAX_STEP_SEQ = 4095 -def torch_token_generator(prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False): - tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) - model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token) + +def torch_token_generator( + prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False +): + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + model = AutoModelForCausalLM.from_pretrained( + hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token + ) initial_input = tokenizer(prompt, return_tensors="pt") input_ids = initial_input.input_ids @@ -40,13 +47,14 @@ def torch_token_generator(prompt, hf_model_name: str, hf_auth_token: str, break_ if next_token_id.item() == tokenizer.eos_token_id and break_on_eos: break + def turbine_token_generator( - prompt: str, - hf_model_name: str, - vmfb_path: str = None, - external_weight_file: str = None, - hf_auth_token: str = None, - break_on_eos: bool = False + prompt: str, + hf_model_name: str, + vmfb_path: str = None, + external_weight_file: str = None, + hf_auth_token: str = None, + break_on_eos: bool = False, ) -> torch.Tensor: """ A generator function for turbine model inference. 
@@ -81,10 +89,7 @@ def turbine_token_generator( raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") # Prepare the modules for the IREE runtime context - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device) - ] + vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)] # Include parameter module if external weight file is used if external_weight_file: @@ -97,7 +102,9 @@ def turbine_token_generator( ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) # Initialize the tokenizer - tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) # Convert the prompt to input tensor initial_input = tokenizer(prompt, return_tensors="pt") @@ -125,7 +132,14 @@ def format_out(results): break -def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): +def run_vmfb_comparison( + prompt, + hf_auth_token, + hf_model_name, + vmfb_path, + external_weight_file, + break_on_eos=True, +): # Initialize generators with the prompt and specific arguments print("Using prompt:") print(prompt) @@ -133,22 +147,26 @@ def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, externa prompt=prompt, hf_auth_token=hf_auth_token, hf_model_name=hf_model_name, - break_on_eos=break_on_eos + break_on_eos=break_on_eos, ) - print("Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while.") + print( + "Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while." + ) torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) del torch_gen # Run turbine until an equal number of tokens has been generated - print("Generating Turbine tokens... The pipeline needs to be initialized first so the first few tokens may take a while.") + print( + "Generating Turbine tokens... The pipeline needs to be initialized first so the first few tokens may take a while." 
+ ) turbine_gen = turbine_token_generator( prompt=prompt, hf_model_name=hf_model_name, vmfb_path=vmfb_path, external_weight_file=external_weight_file, hf_auth_token=hf_auth_token, - break_on_eos=break_on_eos + break_on_eos=break_on_eos, ) turbine_tokens = [] for _ in tqdm(range(len(torch_tokens)), desc="Generating Turbine tokens"): @@ -157,7 +175,9 @@ def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, externa del turbine_gen # Decode and print the outputs - tokenizer = AutoTokenizer.from_pretrained(hf_model_name, use_fast=False, use_auth_token=hf_auth_token) - turbine_str = (tokenizer.decode(torch.tensor(turbine_tokens).numpy())) - torch_str = (tokenizer.decode(torch.tensor(torch_tokens).numpy())) - return turbine_str, torch_str \ No newline at end of file + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + turbine_str = tokenizer.decode(torch.tensor(turbine_tokens).numpy()) + torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) + return turbine_str, torch_str From 476b60a0030efdeca7dbcc3f4c1382222815c014 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Thu, 7 Dec 2023 22:48:17 +0000 Subject: [PATCH 18/28] fix discrepancy between vmfb and torch due to one being quantized f16 and the other unquantized f32 --- python/turbine_models/tests/conftest.py | 6 +- .../tests/stateless_llama_test.py | 65 +++++++++++++--- .../turbine_models/tests/vmfb_comparison.py | 74 ++++++++++++++----- 3 files changed, 114 insertions(+), 31 deletions(-) diff --git a/python/turbine_models/tests/conftest.py b/python/turbine_models/tests/conftest.py index 1b1fe1f67..07da4cd47 100644 --- a/python/turbine_models/tests/conftest.py +++ b/python/turbine_models/tests/conftest.py @@ -5,14 +5,14 @@ def pytest_addoption(parser): def pytest_generate_tests(metafunc): if "quantization" in metafunc.fixturenames: if metafunc.config.getoption("all"): - quantizations = ["int4", None] + quantizations = ["int4", "None"] else: - quantizations = ["int4"] + quantizations = ["None"] metafunc.parametrize("quantization", quantizations) if "precision" in metafunc.fixturenames: if metafunc.config.getoption("all"): precisions = ["f16", "f32"] else: - precisions = ["f16"] + precisions = ["f32"] metafunc.parametrize("precision", precisions) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 2ae38fc25..8a8eb7d1e 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -13,19 +13,41 @@ from typing import Literal -def test_export(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): +import os +import sys +import re + +from typing import Tuple + +os.environ["TORCH_LOGS"] = "dynamic" +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from torch.utils import _pytree as pytree +from shark_turbine.aot import * +from iree.compiler.ir import Context +from iree import runtime as ireert + +from turbine_models.custom_models import remap_gguf +import safetensors + +from tqdm import tqdm + + + +def test_export_vmfb(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): llama.export_transformer_model( hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, compile_to="vmfb", external_weights="safetensors", - # external_weight_file="medllama2_7b_f16_int4.safetensors", Do not export weights because this doesn't get quantized + # 
external_weight_file="Llama-2-7b-chat-hf-function-calling-v2_f16_int4.safetensors", Do not export weights because this doesn't get quantized quantization=quantization, precision=precision, device="llvm-cpu", target_triple="host", ) +# def test_export_safetensors(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): from turbine_models.gen_external_params.gen_external_params import ( gen_external_params, ) @@ -33,25 +55,50 @@ def test_export(quantization: Literal["int4", None], precision: Literal["f16", " gen_external_params( hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", quantization=quantization, - weight_path="medllama2_7b_f16_int4.safetensors", hf_auth_token=None, precision=precision, ) - # def run_vmfb_comparison(prompt, hf_auth_token, hf_model_name, vmfb_path, external_weight_file, break_on_eos=True): +# def test_run_vmfb(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): DEFAULT_PROMPT = """[INST] <> Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST] """ + # cache reference output to avoid having to use torch + TORCH_REFERENCE_STRING = '''Hello! I'm just an AI assistant, I'm here to help you with any questions you may have. However, I must point out that the question "what are you?" is not clear or factually coherent. I'm just a language model, I don't have a physical body or personal identity, so I cannot be described as a "you" in the classical sense. Is there anything else I can help you with?''' + # torch_reference_string = get_torch_string( + # prompt=DEFAULT_PROMPT, + # hf_auth_token=None, + # hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + # tokens_to_compare=50, + # ) + # + torch_str = TORCH_REFERENCE_STRING + + + from .vmfb_comparison import get_turbine_vmfb_string, get_torch_string + + torch_str = get_torch_string( + prompt=DEFAULT_PROMPT, + hf_auth_token=None, + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + tokens_to_compare=50, + precision=precision, + quantization=quantization, + ) - from .vmfb_comparison import run_vmfb_comparison + # store torch string + # include precision and quantization in filename + torch_str_cache_path = f"python/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt" + with open(torch_str_cache_path, "w") as f: + f.write(torch_str) - turbine_str, torch_str = run_vmfb_comparison( + turbine_str = get_turbine_vmfb_string( prompt=DEFAULT_PROMPT, hf_auth_token=None, hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - vmfb_path="medllama2_7b.vmfb", - external_weight_file="medllama2_7b_f16_int4.safetensors", - break_on_eos=True, + vmfb_path="Llama_2_7b_chat_hf_function_calling_v2.vmfb", + external_weight_file="Llama_2_7b_chat_hf_function_calling_v2_f16_int4.safetensors", + tokens_to_compare=50, ) import difflib diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py index 4d4dcc15b..a2ac9b528 100644 --- a/python/turbine_models/tests/vmfb_comparison.py +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -21,14 +21,32 @@ MAX_STEP_SEQ = 4095 +BATCH_SIZE = 1 +MAX_STEP_SEQ = 4095 + + def torch_token_generator( - prompt, hf_model_name: str, hf_auth_token: str, break_on_eos=False + prompt, hf_model_name: str, + hf_auth_token: 
str, + break_on_eos=False, + precision="f32", + quantization="None", ): + if precision == "f16": + torch_dtype = torch.float16 + elif precision == "f32": + torch_dtype = torch.float32 + else: + raise ValueError("Invalid dtype, f16 or f32 supported") + + if quantization is not None and quantization.lower() != "none": + raise NotImplementedError("Quantization not supported for torch") + tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, use_fast=False, use_auth_token=hf_auth_token + hf_model_name, use_fast=False, use_auth_token=hf_auth_token, ) model = AutoModelForCausalLM.from_pretrained( - hf_model_name, torch_dtype=torch.float, use_auth_token=hf_auth_token + hf_model_name, torch_dtype=torch_dtype, use_auth_token=hf_auth_token ) initial_input = tokenizer(prompt, return_tensors="pt") @@ -131,30 +149,53 @@ def format_out(results): if next_token_tensor.item() == tokenizer.eos_token_id and break_on_eos: break - -def run_vmfb_comparison( +def get_torch_string( prompt, hf_auth_token, hf_model_name, - vmfb_path, - external_weight_file, - break_on_eos=True, + tokens_to_compare=50, ): - # Initialize generators with the prompt and specific arguments + + print("Using prompt:") print(prompt) + print("To generate torch reference string...") torch_gen = torch_token_generator( prompt=prompt, hf_auth_token=hf_auth_token, hf_model_name=hf_model_name, - break_on_eos=break_on_eos, + break_on_eos=True, + ) + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token ) print( "Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while." ) + # read until stopiteration torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) - del torch_gen + torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) + + return torch_str + + +def get_turbine_vmfb_string( + prompt, + hf_auth_token, + hf_model_name, + vmfb_path, + external_weight_file, + tokens_to_compare=50, +): + # Initialize generators with the prompt and specific arguments + # check if torch string cache exists + # cache is at python/turbine_models/tests/vmfb_comparison_cached_torch_output.txt + + # Decode and print the outputs + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) # Run turbine until an equal number of tokens has been generated print( @@ -166,18 +207,13 @@ def run_vmfb_comparison( vmfb_path=vmfb_path, external_weight_file=external_weight_file, hf_auth_token=hf_auth_token, - break_on_eos=break_on_eos, + break_on_eos=False, ) turbine_tokens = [] - for _ in tqdm(range(len(torch_tokens)), desc="Generating Turbine tokens"): + for _ in tqdm(range(tokens_to_compare), desc="Generating Turbine tokens"): token = next(turbine_gen) turbine_tokens.append(token) del turbine_gen - # Decode and print the outputs - tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, use_fast=False, use_auth_token=hf_auth_token - ) turbine_str = tokenizer.decode(torch.tensor(turbine_tokens).numpy()) - torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) - return turbine_str, torch_str + return turbine_str \ No newline at end of file From 228937368948065f99394a7f4f4b49feb4c814f5 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 8 Dec 2023 00:58:28 +0000 Subject: [PATCH 19/28] finally test cases passed and black applied --- .../custom_models/stateless_llama.py | 2 +- .../gen_external_params.py | 14 +++- python/turbine_models/tests/conftest.py | 18 ----- 
.../tests/stateless_llama_test.py | 68 +++++++++---------- .../turbine_models/tests/vmfb_comparison.py | 26 ++++--- ...on_cached_torch_output_f32_unquantized.txt | 1 + 6 files changed, 65 insertions(+), 64 deletions(-) delete mode 100644 python/turbine_models/tests/conftest.py create mode 100644 python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 7e1ad4341..0f3810b0c 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -34,7 +34,7 @@ help="HF model name", default="meta-llama/Llama-2-7b-chat-hf", ) -parser.add_argument("--quantization", type=str, default="None") +parser.add_argument("--quantization", type=str, default="unquantized") parser.add_argument("--external_weight_file", type=str, default="") parser.add_argument("--vmfb_path", type=str, default="") parser.add_argument( diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 5bf711994..f81f382fb 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -1,4 +1,5 @@ import re +from typing import Literal from turbine_models.model_builder import HFTransformerBuilder from transformers import AutoTokenizer, AutoModelForCausalLM import torch @@ -57,7 +58,7 @@ def forward(self, x): def gen_external_params( hf_model_name: str = "meta-llama/Llama-2-7b-chat-hf", - quantization: str = "int4", + quantization: Literal["unquantized", "int4", "int8"] = "int4", weight_path: str = "", hf_auth_token: str = None, precision: str = "f16", @@ -71,6 +72,17 @@ def gen_external_params( :param hf_auth_token: The Hugging Face auth token required for some models. :param precision: Data type of model ('f16' or 'f32'). 
""" + SUPPORTED_QUANTIZATIONS = ["unquantized", "int4", "int8"] + if quantization not in SUPPORTED_QUANTIZATIONS: + if ( + quantization is None + or quantization.lower() == "none" + or quantization.lower() == "unquantized" + ): + quantization = "unquantized" + else: + raise ValueError(f"Invalid quantization, {quantization} not supported.") + model_builder = HFTransformerBuilder( example_input=None, hf_id=hf_model_name, diff --git a/python/turbine_models/tests/conftest.py b/python/turbine_models/tests/conftest.py deleted file mode 100644 index 07da4cd47..000000000 --- a/python/turbine_models/tests/conftest.py +++ /dev/null @@ -1,18 +0,0 @@ -def pytest_addoption(parser): - parser.addoption("--all", action="store_true", help="run all combinations") - - -def pytest_generate_tests(metafunc): - if "quantization" in metafunc.fixturenames: - if metafunc.config.getoption("all"): - quantizations = ["int4", "None"] - else: - quantizations = ["None"] - metafunc.parametrize("quantization", quantizations) - - if "precision" in metafunc.fixturenames: - if metafunc.config.getoption("all"): - precisions = ["f16", "f32"] - else: - precisions = ["f32"] - metafunc.parametrize("precision", precisions) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 8a8eb7d1e..6a8b390a8 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -6,7 +6,6 @@ import logging import turbine_models.custom_models.stateless_llama as llama -import unittest import os import pytest @@ -31,10 +30,20 @@ import safetensors from tqdm import tqdm +from .vmfb_comparison import get_turbine_vmfb_string +def test_vmfb_comparison(): + """ + Test that the vmfb model produces the same output as the torch model + + Precision can be 16 or 32, using 16 for speed and memory. + + For VMFB, quantization can be int4 or None, but right now only using none for compatibility with torch. + """ + quantization = "unquantized" + precision = "f32" -def test_export_vmfb(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): llama.export_transformer_model( hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", hf_auth_token=None, @@ -47,7 +56,6 @@ def test_export_vmfb(quantization: Literal["int4", None], precision: Literal["f1 target_triple="host", ) -# def test_export_safetensors(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): from turbine_models.gen_external_params.gen_external_params import ( gen_external_params, ) @@ -59,48 +67,41 @@ def test_export_vmfb(quantization: Literal["int4", None], precision: Literal["f1 precision=precision, ) -# def test_run_vmfb(quantization: Literal["int4", None], precision: Literal["f16", "f32"]): DEFAULT_PROMPT = """[INST] <> Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST] """ - # cache reference output to avoid having to use torch - TORCH_REFERENCE_STRING = '''Hello! I'm just an AI assistant, I'm here to help you with any questions you may have. However, I must point out that the question "what are you?" is not clear or factually coherent. I'm just a language model, I don't have a physical body or personal identity, so I cannot be described as a "you" in the classical sense. 
Is there anything else I can help you with?''' - # torch_reference_string = get_torch_string( - # prompt=DEFAULT_PROMPT, - # hf_auth_token=None, - # hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - # tokens_to_compare=50, - # ) - # - torch_str = TORCH_REFERENCE_STRING - - - from .vmfb_comparison import get_turbine_vmfb_string, get_torch_string - - torch_str = get_torch_string( - prompt=DEFAULT_PROMPT, - hf_auth_token=None, - hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - tokens_to_compare=50, - precision=precision, - quantization=quantization, - ) - # store torch string - # include precision and quantization in filename torch_str_cache_path = f"python/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt" - with open(torch_str_cache_path, "w") as f: - f.write(torch_str) + # if cached, just read + if os.path.exists(torch_str_cache_path): + with open(torch_str_cache_path, "r") as f: + torch_str = f.read() + else: + from .vmfb_comparison import get_torch_string + + torch_str = get_torch_string( + prompt=DEFAULT_PROMPT, + hf_auth_token=None, + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + tokens_to_compare=50, + precision=precision, + quantization=quantization, + ) + + with open(torch_str_cache_path, "w") as f: + f.write(torch_str) turbine_str = get_turbine_vmfb_string( prompt=DEFAULT_PROMPT, hf_auth_token=None, hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", vmfb_path="Llama_2_7b_chat_hf_function_calling_v2.vmfb", - external_weight_file="Llama_2_7b_chat_hf_function_calling_v2_f16_int4.safetensors", + external_weight_file=f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", tokens_to_compare=50, ) + torch_str = torch_str[: len(turbine_str)] + import difflib # Calculate and print diff @@ -113,8 +114,3 @@ def test_export_vmfb(quantization: Literal["int4", None], precision: Literal["f1 ) assert torch_str == turbine_str, "".join(diff) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py index a2ac9b528..ddefcc76c 100644 --- a/python/turbine_models/tests/vmfb_comparison.py +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -26,11 +26,12 @@ def torch_token_generator( - prompt, hf_model_name: str, + prompt, + hf_model_name: str, hf_auth_token: str, break_on_eos=False, precision="f32", - quantization="None", + quantization="unquantized", ): if precision == "f16": torch_dtype = torch.float16 @@ -39,11 +40,17 @@ def torch_token_generator( else: raise ValueError("Invalid dtype, f16 or f32 supported") - if quantization is not None and quantization.lower() != "none": + if ( + quantization is not None + and quantization.lower() != "none" + and quantization.lower() != "unquantized" + ): raise NotImplementedError("Quantization not supported for torch") tokenizer = AutoTokenizer.from_pretrained( - hf_model_name, use_fast=False, use_auth_token=hf_auth_token, + hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, ) model = AutoModelForCausalLM.from_pretrained( hf_model_name, torch_dtype=torch_dtype, use_auth_token=hf_auth_token @@ -149,14 +156,15 @@ def format_out(results): if next_token_tensor.item() == tokenizer.eos_token_id and break_on_eos: break + def get_torch_string( prompt, hf_auth_token, hf_model_name, + precision, + quantization, tokens_to_compare=50, ): - - print("Using prompt:") print(prompt) print("To 
generate torch reference string...") @@ -165,6 +173,8 @@ def get_torch_string( hf_auth_token=hf_auth_token, hf_model_name=hf_model_name, break_on_eos=True, + precision=precision, + quantization=quantization, ) tokenizer = AutoTokenizer.from_pretrained( hf_model_name, use_fast=False, use_auth_token=hf_auth_token @@ -176,7 +186,7 @@ def get_torch_string( # read until stopiteration torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) - + return torch_str @@ -216,4 +226,4 @@ def get_turbine_vmfb_string( del turbine_gen turbine_str = tokenizer.decode(torch.tensor(turbine_tokens).numpy()) - return turbine_str \ No newline at end of file + return turbine_str diff --git a/python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt b/python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt new file mode 100644 index 000000000..2b5de4583 --- /dev/null +++ b/python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt @@ -0,0 +1 @@ +Hello! I'm just an AI assistant, I'm here to help you with any questions you may have. However, I must point out that the question "what are you?" is not clear or factually coherent. I'm just a language model, I don't have a physical body or personal identity, so I cannot be described as a "you" in the classical sense. Is there anything else I can help you with? \ No newline at end of file From 03c022b0cb9a1dcc8eb70e4c45e1108a411c4cb8 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 8 Dec 2023 01:28:23 +0000 Subject: [PATCH 20/28] make llama and sd test separate steps --- .github/workflows/test_models.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index e353273e2..fa6401dd2 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -37,6 +37,10 @@ jobs: pip install -e .[testing] pip install -e python/turbine_models - - name: Run tests + - name: Run stateless_llama tests run: | - pytest python/turbine_models/tests + pytest python/turbine_models/tests/test_stateless_llama.py + + - name: Run sd tests + run: | + pytest python/turbine_models/tests/sd_test.py From c00bc24a884d4e2933201d55da12529951868775 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 8 Dec 2023 01:31:49 +0000 Subject: [PATCH 21/28] typo --- .github/workflows/test_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index fa6401dd2..76c30bbab 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -39,7 +39,7 @@ jobs: - name: Run stateless_llama tests run: | - pytest python/turbine_models/tests/test_stateless_llama.py + pytest python/turbine_models/tests/stateless_llama_test.py - name: Run sd tests run: | From 91b3b72bd6f13b352d932b38ae9c6337f9a8e97f Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Fri, 8 Dec 2023 01:43:24 +0000 Subject: [PATCH 22/28] show mem availability --- .github/workflows/test_models.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 76c30bbab..b02c4a291 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -36,6 +36,10 @@ jobs: pip install --upgrade -r requirements.txt pip install -e .[testing] pip install -e python/turbine_models + + - name: Show current free 
memory
+        run: |
+          free -mh
 
       - name: Run stateless_llama tests
         run: |

From 5e31126c99c9f36899098adcbb8def02d6f48020 Mon Sep 17 00:00:00 2001
From: Xida Ren
Date: Fri, 8 Dec 2023 16:31:34 +0000
Subject: [PATCH 23/28] fix device issue

---
 python/turbine_models/tests/stateless_llama_test.py | 1 +
 python/turbine_models/tests/vmfb_comparison.py      | 7 ++++++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py
index 6a8b390a8..c99cb7c23 100644
--- a/python/turbine_models/tests/stateless_llama_test.py
+++ b/python/turbine_models/tests/stateless_llama_test.py
@@ -98,6 +98,7 @@ def test_vmfb_comparison():
         vmfb_path="Llama_2_7b_chat_hf_function_calling_v2.vmfb",
         external_weight_file=f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
         tokens_to_compare=50,
+        device="llvm-cpu",
     )
 
     torch_str = torch_str[: len(turbine_str)]
diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py
index ddefcc76c..ba951b748 100644
--- a/python/turbine_models/tests/vmfb_comparison.py
+++ b/python/turbine_models/tests/vmfb_comparison.py
@@ -16,6 +16,7 @@
 import safetensors
 
 from tqdm import tqdm
+from typing import Literal
 
 BATCH_SIZE = 1
 MAX_STEP_SEQ = 4095
@@ -73,6 +74,7 @@ def torch_token_generator(
             break
 
 
+
 def turbine_token_generator(
     prompt: str,
     hf_model_name: str,
@@ -80,6 +82,7 @@ def turbine_token_generator(
     external_weight_file: str = None,
     hf_auth_token: str = None,
     break_on_eos: bool = False,
+    device: Literal["llvm-cpu", "cuda", "vulkan", "rocm"] = "llvm-cpu",
 ) -> torch.Tensor:
     """
     A generator function for turbine model inference.
@@ -94,7 +97,7 @@ def turbine_token_generator(
     """
 
     # Create the config for the IREE runtime environment
-    config = ireert.Config("local-task")
+    config = ireert.Config("local-task" if device == "llvm-cpu" else device)
 
     # Load the external weight file if provided
     if external_weight_file:
@@ -196,6 +199,7 @@ def get_turbine_vmfb_string(
     hf_model_name,
     vmfb_path,
     external_weight_file,
+    device,
     tokens_to_compare=50,
 ):
     # Initialize generators with the prompt and specific arguments
@@ -218,6 +222,7 @@ def get_turbine_vmfb_string(
         external_weight_file=external_weight_file,
         hf_auth_token=hf_auth_token,
         break_on_eos=False,
+        device=device,
     )
     turbine_tokens = []
     for _ in tqdm(range(tokens_to_compare), desc="Generating Turbine tokens"):

From cf4b4b1173727cc91cef5c883ad394cafbcf1b41 Mon Sep 17 00:00:00 2001
From: Xida Ren
Date: Mon, 11 Dec 2023 15:16:05 +0000
Subject: [PATCH 24/28] move args back to beginning of file

---
 .../gen_external_params.py | 81 ++++++++++---------
 1 file changed, 41 insertions(+), 40 deletions(-)

diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py
index f81f382fb..703f1ff25 100644
--- a/python/turbine_models/gen_external_params/gen_external_params.py
+++ b/python/turbine_models/gen_external_params/gen_external_params.py
@@ -4,6 +4,47 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
+import argparse
+import sys
+
+parser = argparse.ArgumentParser(
+    description="Quantize and save Hugging Face models."
+) + +parser.add_argument( + "--hf_model_name", + type=str, + default="meta-llama/Llama-2-7b-chat-hf", + help="The Hugging Face model name ID.", +) +parser.add_argument( + "--quantization", + type=str, + default="int4", + choices=["int4", "int8"], + help="Type of quantization to apply.", +) +parser.add_argument( + "--weight_path", + type=str, + default="", + help="Path to save the quantized model weights.", +) +parser.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="The Hugging Face auth token required for some models.", +) +parser.add_argument( + "--precision", + type=str, + default="f16", + choices=["f16", "f32"], + help="Data type of model.", +) + +args = parser.parse_args() def quantize(model, quantization, dtype): accumulates = dtype @@ -116,47 +157,7 @@ def gen_external_params( if __name__ == "__main__": - import argparse - import sys - parser = argparse.ArgumentParser( - description="Quantize and save Hugging Face models." - ) - - parser.add_argument( - "--hf_model_name", - type=str, - default="meta-llama/Llama-2-7b-chat-hf", - help="The Hugging Face model name ID.", - ) - parser.add_argument( - "--quantization", - type=str, - default="int4", - choices=["int4", "int8"], - help="Type of quantization to apply.", - ) - parser.add_argument( - "--weight_path", - type=str, - default="", - help="Path to save the quantized model weights.", - ) - parser.add_argument( - "--hf_auth_token", - type=str, - default=None, - help="The Hugging Face auth token required for some models.", - ) - parser.add_argument( - "--precision", - type=str, - default="f16", - choices=["f16", "f32"], - help="Data type of model.", - ) - - args = parser.parse_args() try: gen_external_params( From 979321c0ea754450858f438e9bed9da91a9c0a8f Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Mon, 11 Dec 2023 15:20:00 +0000 Subject: [PATCH 25/28] remove MAX_STEP_SEQ which is not needed for VMFB_COMPARISON --- python/turbine_models/tests/vmfb_comparison.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py index ba951b748..9cd43365e 100644 --- a/python/turbine_models/tests/vmfb_comparison.py +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -18,12 +18,6 @@ from tqdm import tqdm from typing import Literal -BATCH_SIZE = 1 -MAX_STEP_SEQ = 4095 - - -BATCH_SIZE = 1 -MAX_STEP_SEQ = 4095 def torch_token_generator( From b0006427654e4114790d5863eba8303bbfb030cf Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Mon, 11 Dec 2023 17:39:46 +0000 Subject: [PATCH 26/28] black --- .../gen_external_params/gen_external_params.py | 7 ++----- python/turbine_models/tests/vmfb_comparison.py | 2 -- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 703f1ff25..20f4ab556 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -7,9 +7,7 @@ import argparse import sys -parser = argparse.ArgumentParser( - description="Quantize and save Hugging Face models." 
-) +parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") parser.add_argument( "--hf_model_name", @@ -46,6 +44,7 @@ args = parser.parse_args() + def quantize(model, quantization, dtype): accumulates = dtype int_weights = {} @@ -157,8 +156,6 @@ def gen_external_params( if __name__ == "__main__": - - try: gen_external_params( hf_model_name=args.hf_model_name, diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py index 9cd43365e..112bb89f5 100644 --- a/python/turbine_models/tests/vmfb_comparison.py +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -19,7 +19,6 @@ from typing import Literal - def torch_token_generator( prompt, hf_model_name: str, @@ -68,7 +67,6 @@ def torch_token_generator( break - def turbine_token_generator( prompt: str, hf_model_name: str, From ad865e28ac64b8c5bd76c77450e810382a57bf34 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Mon, 11 Dec 2023 17:52:59 +0000 Subject: [PATCH 27/28] only parse args when in main --- .../turbine_models/gen_external_params/gen_external_params.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 20f4ab556..7c8bf80f2 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -42,7 +42,6 @@ help="Data type of model.", ) -args = parser.parse_args() def quantize(model, quantization, dtype): @@ -156,6 +155,8 @@ def gen_external_params( if __name__ == "__main__": + + args = parser.parse_args() try: gen_external_params( hf_model_name=args.hf_model_name, From 3883224a0008c6f8a8aff04219a8113efef19d77 Mon Sep 17 00:00:00 2001 From: Xida Ren Date: Mon, 11 Dec 2023 17:53:51 +0000 Subject: [PATCH 28/28] black --- .../turbine_models/gen_external_params/gen_external_params.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 7c8bf80f2..df79adcc6 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -43,7 +43,6 @@ ) - def quantize(model, quantization, dtype): accumulates = dtype int_weights = {} @@ -155,7 +154,6 @@ def gen_external_params( if __name__ == "__main__": - args = parser.parse_args() try: gen_external_params(