diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 3227c2769..b02c4a291 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -1,4 +1,4 @@ -name: Test +name: Test Turbine Models on: workflow_dispatch: @@ -8,7 +8,7 @@ on: - main jobs: - test: + test-turbine-models: strategy: matrix: version: [3.11] @@ -36,7 +36,15 @@ jobs: pip install --upgrade -r requirements.txt pip install -e .[testing] pip install -e python/turbine_models + + - name: Show current free memory + run: | + free -mh + + - name: Run stateless_llama tests + run: | + pytest python/turbine_models/tests/stateless_llama_test.py - - name: Run tests + - name: Run sd tests run: | - pytest python/turbine_models/tests + pytest python/turbine_models/tests/sd_test.py diff --git a/python/shark_turbine/dynamo/passes.py b/python/shark_turbine/dynamo/passes.py index 1e8e2058f..478944df6 100644 --- a/python/shark_turbine/dynamo/passes.py +++ b/python/shark_turbine/dynamo/passes.py @@ -46,6 +46,7 @@ torch.ops.aten._to_copy, torch.ops.aten._log_softmax_backward_data, torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten._unsafe_index.Tensor, ] diff --git a/python/shark_turbine/importers/fx_importer.py b/python/shark_turbine/importers/fx_importer.py index e2707f35b..76d7f9978 100644 --- a/python/shark_turbine/importers/fx_importer.py +++ b/python/shark_turbine/importers/fx_importer.py @@ -628,6 +628,16 @@ def _import_torch_op_overload( elif target == torch.ops.aten.lift_fresh_copy.out: node.target = target = torch.ops.aten.clone.out node.args = (node.args[0], None, node.args[1]) + # TODO: generalize empty.memory_format in the future + # Currently, the aten.baddbmm.default op for Unet includes multiplying an + # empty.memory_format input with a constant, which creates NaN values + # because empty.memory_format contains uninitialized data. Converting + # aten.empty.memory_format -> aten.zeros.default fixes the correctness issue + elif target == torch.ops.aten.empty.memory_format: + if len(node.users) == 1: + for key_node in node.users: + if key_node.target == torch.ops.aten.baddbmm.default: + node.target = target = torch.ops.aten.zeros.default schema = target._schema assert isinstance(schema, FunctionSchema) diff --git a/python/turbine_models/custom_models/sd_inference/clip.py b/python/turbine_models/custom_models/sd_inference/clip.py new file mode 100644 index 000000000..ec2dcb3fb --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/clip.py @@ -0,0 +1,201 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import re + +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from transformers import CLIPTextModel, CLIPTokenizer + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument("--run_vmfb", action="store_true") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_file", type=str, default="") +parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + +prompt = ["a photograph of an astronaut riding a horse"] + + +def export_clip_model( + hf_model_name, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_file=None, + device=None, + target_triple=None, + max_alloc=None, +): + # Load the tokenizer and text encoder to tokenize and encode the text. 
+ tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + text_encoder_model = CLIPTextModel.from_pretrained( + hf_model_name, + subfolder="text_encoder", + token=hf_auth_token, + ) + + mapper = {} + utils.save_external_weights( + mapper, text_encoder_model, external_weights, external_weight_file + ) + + class CompiledClip(CompiledModule): + if external_weights: + params = export_parameters( + text_encoder_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(text_encoder_model) + + def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)): + return jittable(text_encoder_model.forward)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledClip(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str, tokenizer + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +def run_clip_vmfb_comparison(args): + config = ireert.Config(args.device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if args.vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + sys.exit("no vmfb_path provided, required for run_vmfb") + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] + if args.external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=config, + ) + tokenizer = CLIPTokenizer.from_pretrained( + args.hf_model_name, + subfolder="tokenizer", + token=args.hf_auth_token, + ) + text_input = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + inp = text_input.input_ids + device_inputs = [ireert.asdevicearray(config.device, inp)] + + # Turbine output + ModuleCompiled = ctx.modules.compiled_clip + turbine_outputs = ModuleCompiled["main"](*device_inputs) + turbine_output = turbine_outputs[0] + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + # Torch output + text_encoder_model = CLIPTextModel.from_pretrained( + args.hf_model_name, + subfolder="text_encoder", + token=args.hf_auth_token, + ) + torch_output = text_encoder_model.forward(inp)[0] + np_torch_output = torch_output.detach().cpu().numpy() + print( + "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype + ) + + err = utils.largest_error(np_torch_output, turbine_output) + print("LARGEST ERROR:", err) + assert err < 9e-5 + + +if __name__ == "__main__": + args = parser.parse_args() + if args.run_vmfb: + run_clip_vmfb_comparison(args) + else: + mod_str, _ = export_clip_model( + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = 
args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/unet.py b/python/turbine_models/custom_models/sd_inference/unet.py new file mode 100644 index 000000000..4c4d3c227 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/unet.py @@ -0,0 +1,215 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import re + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import UNet2DConditionModel + +import safetensors +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument("--run_vmfb", action="store_true") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_file", type=str, default="") +parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + + +class UnetModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + token=hf_auth_token, + ) + self.guidance_scale = 7.5 + + def forward(self, sample, timestep, encoder_hidden_states): + samples = torch.cat([sample] * 2) + unet_out = self.unet.forward( + samples, timestep, encoder_hidden_states, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + +def export_unet_model( + unet_model, + hf_model_name, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_file=None, + device=None, + target_triple=None, + max_alloc=None, +): + mapper = {} + utils.save_external_weights( + mapper, unet_model, external_weights, external_weight_file + ) + + encoder_hidden_states_sizes = (2, 77, 768) + if hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states_sizes = (2, 77, 1024) + + class CompiledUnet(CompiledModule): + if external_weights: + params = export_parameters( + unet_model, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(unet_model) + + def main( + self, + sample=AbstractTensor(1, 4, 64, 64, dtype=torch.float32), + timestep=AbstractTensor(1, 
dtype=torch.float32), + encoder_hidden_states=AbstractTensor( + *encoder_hidden_states_sizes, dtype=torch.float32 + ), + ): + return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledUnet(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +def run_unet_vmfb_comparison(unet_model, args): + config = ireert.Config(args.device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if args.vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + sys.exit("no vmfb_path provided, required for run_vmfb") + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] + if args.external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=config, + ) + sample = torch.rand(1, 4, 64, 64, dtype=torch.float32) + timestep = torch.zeros(1, dtype=torch.float32) + if args.hf_model_name == "CompVis/stable-diffusion-v1-4": + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + + device_inputs = [ + ireert.asdevicearray(config.device, sample), + ireert.asdevicearray(config.device, timestep), + ireert.asdevicearray(config.device, encoder_hidden_states), + ] + + # Turbine output + ModuleCompiled = ctx.modules.compiled_unet + turbine_output = ModuleCompiled["main"](*device_inputs) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + # Torch output + torch_output = unet_model.forward(sample, timestep, encoder_hidden_states) + np_torch_output = torch_output.detach().cpu().numpy() + print( + "TORCH OUTPUT:", np_torch_output, np_torch_output.shape, np_torch_output.dtype + ) + + err = utils.largest_error(np_torch_output, turbine_output) + print("LARGEST ERROR:", err) + assert err < 9e-5 + + +if __name__ == "__main__": + args = parser.parse_args() + unet_model = UnetModel( + args.hf_model_name, + args.hf_auth_token, + ) + if args.run_vmfb: + run_unet_vmfb_comparison(unet_model, args) + else: + mod_str = export_unet_model( + unet_model, + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/sd_inference/utils.py b/python/turbine_models/custom_models/sd_inference/utils.py new file mode 100644 index 000000000..e2f4cfd96 --- 
/dev/null +++ b/python/turbine_models/custom_models/sd_inference/utils.py @@ -0,0 +1,83 @@ +import iree.compiler as ireec +import numpy as np +import safetensors + + +def save_external_weights( + mapper, + model, + external_weights=None, + external_weight_file=None, +): + if external_weights is not None: + if external_weights == "safetensors": + mod_params = dict(model.named_parameters()) + for name in mod_params: + mapper["params." + name] = name + if external_weight_file: + safetensors.torch.save_file(mod_params, external_weight_file) + print("Saved params to", external_weight_file) + + +def largest_error(array1, array2): + absolute_diff = np.abs(array1 - array2) + max_error = np.max(absolute_diff) + return max_error + + +def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): + flags = [ + "--iree-input-type=torch", + "--mlir-print-debuginfo", + "--mlir-print-op-on-diagnostic=false", + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-target-triple=x86_64-linux-gnu", + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-opt-const-expr-hoisting=False", + ] + if device == "cpu": + flags.append("--iree-llvmcpu-enable-ukernels=all") + device = "llvm-cpu" + elif device == "vulkan": + flags.extend( + [ + "--iree-hal-target-backends=vulkan-spirv", + "--iree-vulkan-target-triple=" + target_triple, + "--iree-stream-resource-max-allocation-size=" + max_alloc, + ] + ) + elif device == "rocm": + flags.extend( + [ + "--iree-hal-target-backends=rocm", + "--iree-rocm-target-chip=" + target_triple, + "--iree-rocm-link-bc=true", + "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", + "--iree-vm-bytecode-module-strip-source-map=true", + "--iree-opt-strip-assertions=true", + "--iree-vm-target-truncate-unsupported-floats", + ] + ) + elif device == "cuda": + flags.extend( + [ + "--iree-hal-target-backends=cuda", + "--iree-hal-cuda-llvm-target-arch=" + target_triple, + "--iree-vm-bytecode-module-strip-source-map=true", + "--iree-vm-target-truncate-unsupported-floats", + ] + ) + else: + print("incorrect device: ", device) + + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + extra_args=flags, + ) + with open(f"{safe_name}.vmfb", "wb+") as f: + f.write(flatbuffer_blob) + print("Saved to", safe_name + ".vmfb") + exit() diff --git a/python/turbine_models/custom_models/sd_inference/vae.py b/python/turbine_models/custom_models/sd_inference/vae.py new file mode 100644 index 000000000..cf3c587a1 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/vae.py @@ -0,0 +1,180 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys +import re + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import AutoencoderKL + +import safetensors +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument( + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" +) +parser.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="CompVis/stable-diffusion-v1-4", +) +parser.add_argument("--run_vmfb", action="store_true") +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") +parser.add_argument("--external_weight_file", type=str, default="") +parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--external_weights", + type=str, + default=None, + help="saves ir/vmfb without global weights for size and readability, options [safetensors]", +) +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +# TODO: Bring in detection for target triple +parser.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") + + +class VaeModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + token=hf_auth_token, + ) + + def forward(self, inp): + with torch.no_grad(): + x = self.vae.decode(inp, return_dict=False)[0] + return x + + +def export_vae_model( + vae_model, + hf_model_name, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_file=None, + device=None, + target_triple=None, + max_alloc=None, +): + mapper = {} + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_file + ) + + class CompiledVae(CompiledModule): + params = export_parameters(vae_model) + + def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)): + return jittable(vae_model.forward)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledVae(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) + + +def run_vae_vmfb_comparison(vae_model, args): + config = ireert.Config(args.device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + if args.vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + sys.exit("no vmfb_path provided, required for run_vmfb") + + vm_modules = [ + mod, + ireert.create_hal_module(config.vm_instance, config.device), + ] + if args.external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, 
param_module) + + ctx = ireert.SystemContext( + vm_modules=vm_modules, + config=config, + ) + inp = torch.rand(1, 4, 64, 64, dtype=torch.float32) + device_inputs = [ireert.asdevicearray(config.device, inp)] + + # Turbine output + ModuleCompiled = ctx.modules.compiled_vae + turbine_output = ModuleCompiled["main"](*device_inputs) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + # Torch output + torch_output = vae_model.forward(inp) + torch_output = torch_output.detach().cpu().numpy() + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + + err = utils.largest_error(torch_output, turbine_output) + print("LARGEST ERROR:", err) + assert err < 9e-5 + + +if __name__ == "__main__": + args = parser.parse_args() + vae_model = VaeModel( + args.hf_model_name, + args.hf_auth_token, + ) + if args.run_vmfb: + run_vae_vmfb_comparison(vae_model, args) + else: + mod_str = export_vae_model( + vae_model, + args.hf_model_name, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, + ) + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index 0d3d749d2..0f3810b0c 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -19,7 +19,6 @@ import argparse parser = argparse.ArgumentParser() -parser.add_argument("--run_vmfb", action="store_true") parser.add_argument( "--hf_auth_token", type=str, help="The Hugging Face auth token, required" ) @@ -35,7 +34,7 @@ help="HF model name", default="meta-llama/Llama-2-7b-chat-hf", ) -parser.add_argument("--quantization", type=str, default="None") +parser.add_argument("--quantization", type=str, default="unquantized") parser.add_argument("--external_weight_file", type=str, default="") parser.add_argument("--vmfb_path", type=str, default="") parser.add_argument( @@ -48,12 +47,14 @@ "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) -parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") +parser.add_argument( + "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm" +) # TODO: Bring in detection for target triple parser.add_argument( "--iree_target_triple", type=str, - default="", + default="host", help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") @@ -91,7 +92,7 @@ def export_transformer_model( precision=None, device=None, target_triple=None, - max_alloc=None, + vulkan_max_allocation=None, ): state_schema = pytree.treespec_loads(json_schema) @@ -236,14 +237,15 @@ def forward(token0: torch.Tensor, *state0_flat): "--iree-codegen-check-ir-before-llvm-conversion=false", "--iree-opt-const-expr-hoisting=False", ] - if device == "cpu": - flags.append("--iree-llvmcpu-enable-microkernels") + if device == "cpu" or device == "llvm-cpu": + flags.append("--iree-llvmcpu-enable-ukernels=all") device = "llvm-cpu" elif device == "vulkan": flags.extend( [ "--iree-vulkan-target-triple=" + target_triple, - "--iree-stream-resource-max-allocation-size=" + max_alloc, 
+ "--iree-stream-resource-max-allocation-size=" + + vulkan_max_allocation, ] ) elif device == "rocm": @@ -277,109 +279,28 @@ def forward(token0: torch.Tensor, *state0_flat): with open(f"{safe_name}.vmfb", "wb+") as f: f.write(flatbuffer_blob) print("saved to ", safe_name + ".vmfb") - exit() - - -def run_vmfb_comparison(args): - config = ireert.Config(args.device) - - if args.external_weight_file: - index = ireert.ParameterIndex() - index.load(args.external_weight_file) - - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - if args.vmfb_path: - mod = ireert.VmModule.mmap(config.vm_instance, args.vmfb_path) - elif os.path.exists(f"{safe_name}.vmfb"): - mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") - else: - sys.exit("no vmfb_path provided, required for run_vmfb") - - vm_modules = [ - mod, - ireert.create_hal_module(config.vm_instance, config.device), - ] - if args.external_weight_file: - param_module = ireert.create_io_parameters_module( - config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) + return module_str, tokenizer - ctx = ireert.SystemContext( - vm_modules=vm_modules, - config=config, - ) - tokenizer = AutoTokenizer.from_pretrained( - args.hf_model_name, - use_fast=False, - token=args.hf_auth_token, - ) - initial_input = tokenizer(prompt, return_tensors="pt") - example_input_id = initial_input.input_ids - device_inputs = [ireert.asdevicearray(config.device, example_input_id)] - ModuleCompiled = ctx.modules.state_update - results = ModuleCompiled["run_initialize"](*device_inputs) +# if you're looking for run_vmfb_comparison, it's now in python/turbine_models/tests/vmfb_comparison.py - def format_out(results): - return torch.tensor(results.to_host()[0][0]) +if __name__ == "__main__": + args = parser.parse_args() - model = AutoModelForCausalLM.from_pretrained( + mod_str, _ = export_transformer_model( args.hf_model_name, - torch_dtype=torch.float, - use_auth_token=args.hf_auth_token, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_file, + args.quantization, + args.precision, + args.device, + args.iree_target_triple, + args.vulkan_max_allocation, ) - - def get_token_from_logits(logits): - return torch.argmax(logits[:, -1, :], dim=1) - - base_model_results = model.forward(example_input_id) - base_model_token = get_token_from_logits(base_model_results.logits) - bm_pkv = base_model_results.past_key_values - turbine_results = [] - torch_results = [] - turbine_results.append(format_out(results)) - torch_results.append(int(base_model_token)) - while base_model_token != 2: - results = ModuleCompiled["run_forward"](results) - base_model_results = model.forward( - torch.unsqueeze(base_model_token, 0), past_key_values=bm_pkv - ) - base_model_token = get_token_from_logits(base_model_results.logits) - - bm_pkv = base_model_results.past_key_values - # uncomment to see tokens as they are emittd - # print(f"pytorch: {tokenizer.decode(base_model_token)}") - # print(f"turbine: {tokenizer.decode(format_out(results))}") - turbine_results.append(format_out(results)) - torch_results.append(int(base_model_token[0])) - - print("turbine output: ") - print(tokenizer.decode(turbine_results)) - print("\ntorch output: ") - print(tokenizer.decode(torch_results)) - - -if __name__ == "__main__": - args = parser.parse_args() - if args.run_vmfb: - run_vmfb_comparison(args) - else: - mod_str, _ = export_transformer_model( - args.hf_model_name, - args.hf_auth_token, - 
args.compile_to, - args.external_weights, - args.external_weight_file, - args.quantization, - args.precision, - args.device, - args.iree_target_triple, - args.vulkan_max_allocation, - ) - safe_name = args.hf_model_name.split("/")[-1].strip() - safe_name = re.sub("-", "_", safe_name) - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to ", safe_name + ".mlir") + safe_name = args.hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to ", safe_name + ".mlir") diff --git a/python/turbine_models/gen_external_params/gen_external_params.py b/python/turbine_models/gen_external_params/gen_external_params.py index 7e993de72..df79adcc6 100644 --- a/python/turbine_models/gen_external_params/gen_external_params.py +++ b/python/turbine_models/gen_external_params/gen_external_params.py @@ -1,24 +1,45 @@ import re +from typing import Literal from turbine_models.model_builder import HFTransformerBuilder from transformers import AutoTokenizer, AutoModelForCausalLM import torch import argparse +import sys + +parser = argparse.ArgumentParser(description="Quantize and save Hugging Face models.") -parser = argparse.ArgumentParser() parser.add_argument( "--hf_model_name", type=str, - help="HF model name ID", default="meta-llama/Llama-2-7b-chat-hf", + help="The Hugging Face model name ID.", ) -parser.add_argument("--quantization", type=str, default="int4") -parser.add_argument("--weight_path", type=str, default="") parser.add_argument( - "--hf_auth_token", type=str, help="The HF auth token required for some models" + "--quantization", + type=str, + default="int4", + choices=["int4", "int8"], + help="Type of quantization to apply.", ) parser.add_argument( - "--precision", type=str, default="f16", help="Data type of model [f16, f32]" + "--weight_path", + type=str, + default="", + help="Path to save the quantized model weights.", +) +parser.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="The Hugging Face auth token required for some models.", +) +parser.add_argument( + "--precision", + type=str, + default="f16", + choices=["f16", "f32"], + help="Data type of model.", ) @@ -73,34 +94,75 @@ def forward(self, x): return all_weights -if __name__ == "__main__": - args = parser.parse_args() +def gen_external_params( + hf_model_name: str = "meta-llama/Llama-2-7b-chat-hf", + quantization: Literal["unquantized", "int4", "int8"] = "int4", + weight_path: str = "", + hf_auth_token: str = None, + precision: str = "f16", +): + """ + Main function to run the model quantization and saving process. + + :param hf_model_name: The Hugging Face model name ID. + :param quantization: Type of quantization to apply ('int4' or 'int8'). + :param weight_path: Path to save the quantized model weights. + :param hf_auth_token: The Hugging Face auth token required for some models. + :param precision: Data type of model ('f16' or 'f32'). 
+ """ + SUPPORTED_QUANTIZATIONS = ["unquantized", "int4", "int8"] + if quantization not in SUPPORTED_QUANTIZATIONS: + if ( + quantization is None + or quantization.lower() == "none" + or quantization.lower() == "unquantized" + ): + quantization = "unquantized" + else: + raise ValueError(f"Invalid quantization, {quantization} not supported.") + model_builder = HFTransformerBuilder( example_input=None, - hf_id=args.hf_model_name, + hf_id=hf_model_name, auto_model=AutoModelForCausalLM, - hf_auth_token=args.hf_auth_token, + hf_auth_token=hf_auth_token, ) model_builder.build_model() - if args.precision == "f16": + + if precision == "f16": model = model_builder.model.half() dtype = torch.float16 - elif args.precision == "f32": + elif precision == "f32": model = model_builder.model dtype = torch.float32 else: - sys.exit("invalid precision, f16 or f32 supported") - quant_weights = quantize(model, args.quantization, dtype) - # TODO: Add more than just safetensor support - import safetensors + sys.exit("Invalid precision, f16 or f32 supported") + + quant_weights = quantize(model, quantization, dtype) - if args.weight_path == "": - save_path = args.hf_model_name.split("/")[-1].strip() + if weight_path == "": + save_path = hf_model_name.split("/")[-1].strip() save_path = re.sub("-", "_", save_path) - save_path = ( - save_path + "_" + args.precision + "_" + args.quantization + ".safetensors" - ) + save_path = save_path + "_" + precision + "_" + quantization + ".safetensors" else: - save_path = args.weight_path + save_path = weight_path + + import safetensors + safetensors.torch.save_file(quant_weights, save_path) print("Saved safetensor output to ", save_path) + + +if __name__ == "__main__": + args = parser.parse_args() + try: + gen_external_params( + hf_model_name=args.hf_model_name, + quantization=args.quantization, + weight_path=args.weight_path, + hf_auth_token=args.hf_auth_token, + precision=args.precision, + ) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) diff --git a/python/turbine_models/setup.py b/python/turbine_models/setup.py index 33aa3fdb5..4fbc0d256 100644 --- a/python/turbine_models/setup.py +++ b/python/turbine_models/setup.py @@ -63,5 +63,7 @@ def load_version_info(): "protobuf", "sentencepiece", "transformers", + "accelerate", + "diffusers==0.10.2", ], ) diff --git a/python/turbine_models/tests/gen_external_params_test.py b/python/turbine_models/tests/gen_external_params_test.py new file mode 100644 index 000000000..503a2af6a --- /dev/null +++ b/python/turbine_models/tests/gen_external_params_test.py @@ -0,0 +1,132 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from turbine_models.gen_external_params.gen_external_params import quantize +from turbine_models.model_builder import HFTransformerBuilder +from transformers import AutoTokenizer, AutoModelForCausalLM +import unittest +import os +import torch +import pytest + + +class ExternalParamsTest(unittest.TestCase): + def testQuantizeF32(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model, "", torch.float32) + for weight in quant_weights: + self.assertNotIn("weight_zp", weight) + self.assertNotIn("weight_scale", weight) + assert quant_weights[weight].dtype in [torch.float32] + + def testQuantizeF32I8(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model, "int8", torch.float32) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float32] + else: + assert quant_weights[weight].dtype in [torch.float32] + + def testQuantizeF32I4(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model, "int4", torch.float32) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == 2 * quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float32] + else: + assert quant_weights[weight].dtype in [torch.float32] + + def testQuantizeF16(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model.half(), "", torch.float16) + for weight in quant_weights: + self.assertNotIn("weight_zp", weight) + self.assertNotIn("weight_scale", weight) + assert quant_weights[weight].dtype in [torch.float16] + + @pytest.mark.xfail(reason="brevitas issue with f16 int8 quanttensor") + def testQuantizeF16I8(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model.half(), "int8", torch.float16) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float16] + 
else: + assert quant_weights[weight].dtype in [torch.float16] + + def testQuantizeF16I4(self): + model_builder = HFTransformerBuilder( + example_input=None, + hf_id="facebook/opt-350m", + auto_model=AutoModelForCausalLM, + ) + model_builder.build_model() + quant_weights = quantize(model_builder.model.half(), "int4", torch.float16) + named_params = dict(model_builder.model.named_parameters()) + for weight in quant_weights: + if "weight_scale" not in weight and "weight_zp" not in weight: + if "layers" in weight and "weight" in weight and "norm" not in weight: + assert quant_weights[weight].dtype in [torch.uint8] + assert named_params[weight].size(dim=1) == 2 * quant_weights[ + weight + ].size(dim=1) + else: + assert quant_weights[weight].dtype in [torch.float16] + else: + assert quant_weights[weight].dtype in [torch.float16] + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/python/turbine_models/tests/llama_test.py b/python/turbine_models/tests/llama_test.py deleted file mode 100644 index df44a0e70..000000000 --- a/python/turbine_models/tests/llama_test.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2023 Nod Labs, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import logging -import turbine_models.custom_models.stateless_llama as llama -import unittest -import os - - -class LLamaTest(unittest.TestCase): - def testExportTransformerModel(self): - llama.export_transformer_model( - # This is a public model, so no auth required - "llSourcell/medllama2_7b", - None, - "torch", - "safetensors", - "medllama2_f32.safetensors", - None, - "f32", - ) - os.remove("medllama2_f32.safetensors") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/python/turbine_models/tests/sd_test.py b/python/turbine_models/tests/sd_test.py new file mode 100644 index 000000000..c3834840d --- /dev/null +++ b/python/turbine_models/tests/sd_test.py @@ -0,0 +1,101 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import logging +from turbine_models.custom_models.sd_inference import clip, unet, vae +import unittest +import os + + +arguments = { + "hf_auth_token": None, + "hf_model_name": "CompVis/stable-diffusion-v1-4", + "run_vmfb": True, + "compile_to": None, + "external_weight_file": "", + "vmfb_path": "", + "external_weights": None, + "device": "local-task", + "iree_target_triple": "", + "vulkan_max_allocation": "4294967296", +} + + +unet_model = unet.UnetModel( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, +) + +vae_model = vae.VaeModel( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, +) + + +class StableDiffusionTest(unittest.TestCase): + def testExportClipModel(self): + with self.assertRaises(SystemExit) as cm: + clip.export_clip_model( + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_clip.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_file"] = "stable_diffusion_v1_4_clip.safetensors" + namespace = argparse.Namespace(**arguments) + clip.run_clip_vmfb_comparison(namespace) + os.remove("stable_diffusion_v1_4_clip.safetensors") + os.remove("stable_diffusion_v1_4.vmfb") + + def testExportUnetModel(self): + with self.assertRaises(SystemExit) as cm: + unet.export_unet_model( + unet_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_unet.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_file"] = "stable_diffusion_v1_4_unet.safetensors" + namespace = argparse.Namespace(**arguments) + unet.run_unet_vmfb_comparison(unet_model, namespace) + os.remove("stable_diffusion_v1_4_unet.safetensors") + os.remove("stable_diffusion_v1_4.vmfb") + + def testExportVaeModel(self): + with self.assertRaises(SystemExit) as cm: + vae.export_vae_model( + vae_model, + # This is a public model, so no auth required + "CompVis/stable-diffusion-v1-4", + None, + "vmfb", + "safetensors", + "stable_diffusion_v1_4_vae.safetensors", + "cpu", + ) + self.assertEqual(cm.exception.code, None) + arguments["external_weight_file"] = "stable_diffusion_v1_4_vae.safetensors" + namespace = argparse.Namespace(**arguments) + vae.run_vae_vmfb_comparison(vae_model, namespace) + os.remove("stable_diffusion_v1_4_vae.safetensors") + os.remove("stable_diffusion_v1_4.vmfb") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py new file mode 100644 index 000000000..c99cb7c23 --- /dev/null +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -0,0 +1,117 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import turbine_models.custom_models.stateless_llama as llama +import os +import pytest + +from typing import Literal + + +import os +import sys +import re + +from typing import Tuple + +os.environ["TORCH_LOGS"] = "dynamic" +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from torch.utils import _pytree as pytree +from shark_turbine.aot import * +from iree.compiler.ir import Context +from iree import runtime as ireert + +from turbine_models.custom_models import remap_gguf +import safetensors + +from tqdm import tqdm +from .vmfb_comparison import get_turbine_vmfb_string + + +def test_vmfb_comparison(): + """ + Test that the vmfb model produces the same output as the torch model + + Precision can be 16 or 32, using 16 for speed and memory. + + For VMFB, quantization can be int4 or None, but right now only using none for compatibility with torch. + """ + quantization = "unquantized" + precision = "f32" + + llama.export_transformer_model( + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + compile_to="vmfb", + external_weights="safetensors", + # external_weight_file="Llama-2-7b-chat-hf-function-calling-v2_f16_int4.safetensors", Do not export weights because this doesn't get quantized + quantization=quantization, + precision=precision, + device="llvm-cpu", + target_triple="host", + ) + + from turbine_models.gen_external_params.gen_external_params import ( + gen_external_params, + ) + + gen_external_params( + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + quantization=quantization, + hf_auth_token=None, + precision=precision, + ) + + DEFAULT_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? 
[/INST] +""" + + torch_str_cache_path = f"python/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt" + # if cached, just read + if os.path.exists(torch_str_cache_path): + with open(torch_str_cache_path, "r") as f: + torch_str = f.read() + else: + from .vmfb_comparison import get_torch_string + + torch_str = get_torch_string( + prompt=DEFAULT_PROMPT, + hf_auth_token=None, + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + tokens_to_compare=50, + precision=precision, + quantization=quantization, + ) + + with open(torch_str_cache_path, "w") as f: + f.write(torch_str) + + turbine_str = get_turbine_vmfb_string( + prompt=DEFAULT_PROMPT, + hf_auth_token=None, + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + vmfb_path="Llama_2_7b_chat_hf_function_calling_v2.vmfb", + external_weight_file=f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", + tokens_to_compare=50, + device="llvm-cpu", + ) + + torch_str = torch_str[: len(turbine_str)] + + import difflib + + # Calculate and print diff + diff = difflib.unified_diff( + torch_str.splitlines(keepends=True), + turbine_str.splitlines(keepends=True), + fromfile="torch_str", + tofile="turbine_str", + lineterm="", + ) + + assert torch_str == turbine_str, "".join(diff) diff --git a/python/turbine_models/tests/vmfb_comparison.py b/python/turbine_models/tests/vmfb_comparison.py new file mode 100644 index 000000000..112bb89f5 --- /dev/null +++ b/python/turbine_models/tests/vmfb_comparison.py @@ -0,0 +1,226 @@ +import os +import sys +import re + +from typing import Tuple + +os.environ["TORCH_LOGS"] = "dynamic" +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from torch.utils import _pytree as pytree +from shark_turbine.aot import * +from iree.compiler.ir import Context +from iree import runtime as ireert + +from turbine_models.custom_models import remap_gguf +import safetensors + +from tqdm import tqdm +from typing import Literal + + +def torch_token_generator( + prompt, + hf_model_name: str, + hf_auth_token: str, + break_on_eos=False, + precision="f32", + quantization="unquantized", +): + if precision == "f16": + torch_dtype = torch.float16 + elif precision == "f32": + torch_dtype = torch.float32 + else: + raise ValueError("Invalid dtype, f16 or f32 supported") + + if ( + quantization is not None + and quantization.lower() != "none" + and quantization.lower() != "unquantized" + ): + raise NotImplementedError("Quantization not supported for torch") + + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + model = AutoModelForCausalLM.from_pretrained( + hf_model_name, torch_dtype=torch_dtype, use_auth_token=hf_auth_token + ) + + initial_input = tokenizer(prompt, return_tensors="pt") + input_ids = initial_input.input_ids + past_key_values = None + + while True: + model_results = model.forward(input_ids, past_key_values=past_key_values) + logits = model_results.logits + next_token_id = torch.argmax(logits[:, -1, :], dim=1) + past_key_values = model_results.past_key_values + + yield next_token_id + input_ids = next_token_id.unsqueeze(0) # Prepare for the next iteration + + if next_token_id.item() == tokenizer.eos_token_id and break_on_eos: + break + + +def turbine_token_generator( + prompt: str, + hf_model_name: str, + vmfb_path: str = None, + external_weight_file: str = None, + hf_auth_token: str = None, + break_on_eos: bool = False, + device: Literal["llvm-cpu", "cuda", "vulcan", 
"rocm"] = "llvm-cpu", +) -> torch.Tensor: + """ + A generator function for turbine model inference. + + :param prompt: The input prompt for the model. + :param hf_model_name: The name of the Hugging Face model. + :param vmfb_path: Path to the .vmfb model file. + :param external_weight_file: Path to the external weight file (optional). + :param hf_auth_token: Hugging Face authorization token (optional). + :param break_on_eos: Whether to break the loop on end-of-sentence token. + :return: Yields a tensor representing the generated token. + """ + + # Create the config for the IREE runtime environment + config = ireert.Config("local-task" if device == "llvm-cpu" else device) + + # Load the external weight file if provided + if external_weight_file: + index = ireert.ParameterIndex() + index.load(external_weight_file) + + # Ensure model name is in a safe format + safe_name = hf_model_name.split("/")[-1].strip() + safe_name = re.sub("-", "_", safe_name) + + # Load the .vmfb model file + if vmfb_path: + mod = ireert.VmModule.mmap(config.vm_instance, vmfb_path) + elif os.path.exists(f"{safe_name}.vmfb"): + mod = ireert.VmModule.mmap(config.vm_instance, f"{safe_name}.vmfb") + else: + raise FileNotFoundError("No vmfb_path provided, required for run_vmfb") + + # Prepare the modules for the IREE runtime context + vm_modules = [mod, ireert.create_hal_module(config.vm_instance, config.device)] + + # Include parameter module if external weight file is used + if external_weight_file: + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(0, param_module) + + # Create the system context with the given configuration and modules + ctx = ireert.SystemContext(vm_modules=vm_modules, config=config) + + # Initialize the tokenizer + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + + # Convert the prompt to input tensor + initial_input = tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + device_inputs = [ireert.asdevicearray(config.device, example_input_id)] + + # Get the compiled module + ModuleCompiled = ctx.modules.state_update + results = ModuleCompiled["run_initialize"](*device_inputs) + + def format_out(results): + # Convert the output to a PyTorch tensor + return torch.tensor(results.to_host()[0][0]) + + # Token generation loop + while True: + next_token_tensor = format_out(results) + yield next_token_tensor.item() # Yield the scalar value of the tensor + + # Run the next step of the model + results = ModuleCompiled["run_forward"](results) + + # Check for the end-of-sentence token + if next_token_tensor.item() == tokenizer.eos_token_id and break_on_eos: + break + + +def get_torch_string( + prompt, + hf_auth_token, + hf_model_name, + precision, + quantization, + tokens_to_compare=50, +): + print("Using prompt:") + print(prompt) + print("To generate torch reference string...") + torch_gen = torch_token_generator( + prompt=prompt, + hf_auth_token=hf_auth_token, + hf_model_name=hf_model_name, + break_on_eos=True, + precision=precision, + quantization=quantization, + ) + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + + print( + "Generating Torch tokens... The pipeline needs to be initialized first so the first few tokens may take a while." 
+ ) + # read until stopiteration + torch_tokens = list(tqdm(torch_gen, desc="Generating Torch tokens")) + torch_str = tokenizer.decode(torch.tensor(torch_tokens).numpy()) + + return torch_str + + +def get_turbine_vmfb_string( + prompt, + hf_auth_token, + hf_model_name, + vmfb_path, + external_weight_file, + device, + tokens_to_compare=50, +): + # Initialize generators with the prompt and specific arguments + # check if torch string cache exists + # cache is at python/turbine_models/tests/vmfb_comparison_cached_torch_output.txt + + # Decode and print the outputs + tokenizer = AutoTokenizer.from_pretrained( + hf_model_name, use_fast=False, use_auth_token=hf_auth_token + ) + + # Run turbine until an equal number of tokens has been generated + print( + "Generating Turbine tokens... The pipeline needs to be initialized first so the first few tokens may take a while." + ) + turbine_gen = turbine_token_generator( + prompt=prompt, + hf_model_name=hf_model_name, + vmfb_path=vmfb_path, + external_weight_file=external_weight_file, + hf_auth_token=hf_auth_token, + break_on_eos=False, + device=device, + ) + turbine_tokens = [] + for _ in tqdm(range(tokens_to_compare), desc="Generating Turbine tokens"): + token = next(turbine_gen) + turbine_tokens.append(token) + del turbine_gen + + turbine_str = tokenizer.decode(torch.tensor(turbine_tokens).numpy()) + return turbine_str diff --git a/python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt b/python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt new file mode 100644 index 000000000..2b5de4583 --- /dev/null +++ b/python/turbine_models/tests/vmfb_comparison_cached_torch_output_f32_unquantized.txt @@ -0,0 +1 @@ +Hello! I'm just an AI assistant, I'm here to help you with any questions you may have. However, I must point out that the question "what are you?" is not clear or factually coherent. I'm just a language model, I don't have a physical body or personal identity, so I cannot be described as a "you" in the classical sense. Is there anything else I can help you with? \ No newline at end of file diff --git a/turbine-models-requirements.txt b/turbine-models-requirements.txt index 606fea685..c93f5dcec 100644 --- a/turbine-models-requirements.txt +++ b/turbine-models-requirements.txt @@ -2,4 +2,6 @@ protobuf sentencepiece shark_turbine transformers +accelerate +diffusers==0.10.2 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
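For reference, a minimal export-and-compare run of the new CLIP entry point might look like the following. This is a sketch based on the argparse flags and the sd_test.py flow above; the vmfb name is the default derived from the public CompVis/stable-diffusion-v1-4 model (hyphens become underscores), the safetensors name matches what sd_test.py uses, and no auth token is needed for this public model.

    # Export the text encoder, writing external safetensors weights and stable_diffusion_v1_4.vmfb
    python python/turbine_models/custom_models/sd_inference/clip.py \
        --hf_model_name=CompVis/stable-diffusion-v1-4 \
        --compile_to=vmfb \
        --external_weights=safetensors \
        --external_weight_file=stable_diffusion_v1_4_clip.safetensors \
        --device=cpu

    # Compare the compiled module against the torch reference
    python python/turbine_models/custom_models/sd_inference/clip.py \
        --run_vmfb \
        --hf_model_name=CompVis/stable-diffusion-v1-4 \
        --external_weight_file=stable_diffusion_v1_4_clip.safetensors \
        --device=local-task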