diff --git a/.gitignore b/.gitignore
index b564b2eb8..d85c8598b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,4 +21,10 @@ _python_build/
 dist/
 wheelhouse
 *.egg-info
-*.whl
\ No newline at end of file
+*.whl
+
+# Model artifacts
+*.pt
+*.safetensors
+*.gguf
+*.vmfb
diff --git a/python/shark_turbine/aot/compiled_module.py b/python/shark_turbine/aot/compiled_module.py
index 9808ffeb4..ee8eb5490 100644
--- a/python/shark_turbine/aot/compiled_module.py
+++ b/python/shark_turbine/aot/compiled_module.py
@@ -334,6 +334,7 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
                 continue
             del_attr_keys.add(key)
             info.def_attribute(key, value)
+
         for key in del_attr_keys:
             del dct[key]
 
@@ -343,6 +344,17 @@ def __new__(mcls, name: str, bases, dct, *, export_name: Optional[str] = None):
             if key not in dct:
                 dct[key] = _blackhole_instance_attribute
 
+        # Inherit methods, globals, and exports from parent classes,
+        # e.g. when building a child class of StatelessLlama.
+        for base in bases:
+            if base is CompiledModule:
+                continue
+            base_exports = _all_compiled_module_class_infos[base].all_exports
+            for export_name in base_exports:
+                if export_name in info.all_exports:
+                    continue
+                info.all_exports[export_name] = base_exports[export_name]
+
         # Finish construction.
         new_class = type.__new__(mcls, name, bases, dct)
         _all_compiled_module_class_infos[new_class] = info
diff --git a/python/shark_turbine/aot/support/procedural/globals.py b/python/shark_turbine/aot/support/procedural/globals.py
index c186538d8..8bbbdee6e 100644
--- a/python/shark_turbine/aot/support/procedural/globals.py
+++ b/python/shark_turbine/aot/support/procedural/globals.py
@@ -241,6 +241,10 @@ def resolve_assignment(self, proc_trace: "IrTrace", ir_values: Sequence[Value]):
         with proc_trace.loc, proc_trace.ip:
             util_d.GlobalStoreOp(ir_values[0], self.symbol_name)
 
+    def set(self, other):
+        t = current_ir_trace()
+        self.resolve_assignment(t, super().set(other).ir_values)
+
     def __repr__(self):
         return (
             f""
diff --git a/python/shark_turbine/aot/support/procedural/primitives.py b/python/shark_turbine/aot/support/procedural/primitives.py
index 0e07d8f48..e7fdc9419 100644
--- a/python/shark_turbine/aot/support/procedural/primitives.py
+++ b/python/shark_turbine/aot/support/procedural/primitives.py
@@ -68,6 +68,26 @@ class IrScalar(Intrinsic):
     def __init__(self, ir_type: IrType):
         self.ir_type = ir_type
 
+    def set(self, other):
+        t = current_ir_trace()
+        with t.ip, t.loc:
+            # Type check and promotion.
+            # TODO: Add a more comprehensive type promotion hierarchy.
+            lhs = self.ir_value
+            rhs = None
+            if isinstance(other, IrScalar):
+                # Assumes that when both are IR Values, they have the same type.
+                rhs = other.ir_value
+            elif isinstance(other, (int, bool)) and _is_integer_like_type(self.ir_type):
+                rhs = arith_d.ConstantOp(lhs.type, other).result
+            elif isinstance(other, float) and _is_float_type(self.ir_type):
+                rhs = arith_d.ConstantOp(lhs.type, other).result
+            if rhs is None or lhs.type != rhs.type:
+                raise ValueError(
+                    f"Cannot handle src type of {self.ir_type} to dst python type of {type(other)}."
+                )
+            return IrImmediateScalar(rhs)
+
     def __add__(self, other):
         t = current_ir_trace()
         with t.ip, t.loc:
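
The `set()` methods added above let procedural code assign directly to a mutable exported scalar global or an IR scalar. A minimal usage sketch, mirroring the `testSetScalarState` test added later in this patch (the module, global, and method names below are illustrative only):

```python
from iree.compiler.ir import Context
from shark_turbine.aot import *

class CounterModule(CompiledModule):
    # Mutable exported scalar globals; procedures can assign to them via .set().
    seq_step = export_global(AbstractIndex, mutable=True)
    scale = export_global(AbstractF32, mutable=True)

    def reset(self):
        self.seq_step.set(0)
        self.scale.set(1.0)

# Instantiating the module traces `reset` into MLIR containing util.global.store ops.
inst = CounterModule(context=Context(), import_to=None)
print(CompiledModule.get_mlir_module(inst))
```
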
diff --git a/python/turbine_models/custom_models/README.md b/python/turbine_models/custom_models/README.md
new file mode 100644
index 000000000..247354607
--- /dev/null
+++ b/python/turbine_models/custom_models/README.md
@@ -0,0 +1,39 @@
+# Instructions
+
+Clone and install SHARK-Turbine:
+```
+git clone https://github.com/nod-ai/SHARK-Turbine.git
+cd SHARK-Turbine
+python -m venv turbine_venv && source turbine_venv/bin/activate
+
+pip install --upgrade -r requirements.txt
+pip install --upgrade -e .[torch-cpu-nightly,testing]
+pip install --upgrade -r turbine-models-requirements.txt
+```
+
+## Compiling LLMs
+Note: Make sure to replace "your_token" with your actual hf_auth_token in all of the commands below.
+
+First, generate the quantized weight file with
+```
+python python/turbine_models/gen_external_params/gen_external_params.py --hf_auth_token=your_token
+```
+The model weights will be saved in the current directory as `Llama_2_7b_chat_hf_f16_int4.safetensors`.
+
+Then compile llama to a vmfb with
+```
+python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16"
+```
+By default the vmfb will be saved as `Llama_2_7b_chat_hf.vmfb`.
+
+## Running LLMs
+There are two ways of running LLMs:
+
+1) A single run with a predefined prompt, to validate correctness:
+```
+python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token
+```
+2) Interactive CLI chat mode (add the `--chat_mode` flag):
+```
+python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token --chat_mode
+```
\ No newline at end of file
diff --git a/python/turbine_models/custom_models/llm_optimizations/__init__.py b/python/turbine_models/custom_models/llm_optimizations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/turbine_models/custom_models/llm_optimizations/streaming_llm/README.md b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/README.md
new file mode 100644
index 000000000..f212446ec
--- /dev/null
+++ b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/README.md
@@ -0,0 +1,32 @@
+# StreamingLLM
+
+StreamingLLM is based on the paper *"Efficient Streaming Language Models with Attention Sinks"* by Xiao et al. from the MIT Han Lab. Here are the original [[paper](http://arxiv.org/abs/2309.17453)] and [[code](https://github.com/mit-han-lab/streaming-llm)].
+
+The modify_llama.py code here is closely based on the modify_llama.py in the original repo, tweaked to work with ToM HuggingFace and to be compilable through Turbine.
+
+The work introduces attention sinks: attention over a small, fixed set of initial tokens is always preserved and combined with a sliding-window attention over the most recent tokens. This is beneficial because it lets you:
+
+1) Generate arbitrarily long context.
+2) Keep memory usage under a fixed threshold (controlled by window_length).
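+
+The cache-eviction idea can be sketched in a few lines of PyTorch. This is an illustration only, not the Turbine implementation in this patch; the helper name `evict_kv` and the flat `[seq_len, ...]` cache layout are assumptions for the example, while the default sink/window sizes mirror the values hard-coded in `stateless_llama.py`'s `evict_kvcache_space`:
+
+```python
+import torch
+
+def evict_kv(kv_cache: torch.Tensor, sink_size: int = 4, window_size: int = 252) -> torch.Tensor:
+    # kv_cache: per-layer key or value cache laid out as [seq_len, ...], oldest tokens first.
+    seq_len = kv_cache.shape[0]
+    if seq_len <= sink_size + window_size:
+        return kv_cache  # nothing to evict yet
+    sink = kv_cache[:sink_size]                 # always keep the first few "sink" tokens
+    recent = kv_cache[seq_len - window_size :]  # plus the most recent window
+    return torch.cat([sink, recent], dim=0)
+```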
+
+## Compiling LLMs with StreamingLLM
+
+Add an extra `--streaming_llm` flag when you call stateless_llama to generate your vmfb. For example:
+```
+python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16" --streaming_llm
+```
+
+By default the vmfb will still be saved as `Llama_2_7b_chat_hf.vmfb`.
+
+## Running LLMs with StreamingLLM
+
+Similarly, add an extra `--streaming_llm` flag when you call llm_runner.py. For example:
+```
+python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token --chat_mode --streaming_llm
+```
+
+## Future Work:
+- [ ] Make the window size configurable from Python. Everything is in place, but we need to initialize the global with a default value, which will only be possible once `_create_initial_value` can take an initial value from a GlobalAttribute, somewhere [here](https://github.com/nod-ai/SHARK-Turbine/blob/18e8a4100b61adfd9425dd32f780dc5f90017813/python/shark_turbine/aot/support/ir_utils.py#L284-L316).
+- [ ] Get flow.move to enable overlap between the sliding window and the source data. (Currently we need to evict only once the cache is at least 2x the window size.) For example, by default our StreamingLLM window_size is 256, so we evict at ~600 tokens (slightly more than 2x, for safety).
+- [ ] Introduce re-rotation of RoPE as seen [here](https://github.com/huggingface/transformers/blob/c2d283a64a7f33547952e3eb0fa6533fc375bcdd/src/transformers/cache_utils.py#L213-L218) to remove the invasive modification of the LlamaAttention module for StreamingLLM.
\ No newline at end of file
diff --git a/python/turbine_models/custom_models/llm_optimizations/streaming_llm/__init__.py b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py
new file mode 100644
index 000000000..a496b461f
--- /dev/null
+++ b/python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py
@@ -0,0 +1,171 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+import torch.utils.checkpoint
+
+import torch.nn.functional as F
+
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    rotate_half,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+import types
+
+__all__ = ["enable_llama_pos_shift_attention"]
+
+
+def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +def llama_pos_shift_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = ( + self.num_key_value_heads * self.head_dim + ) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [ + F.linear(hidden_states, query_slices[i]) + for i in range(self.config.pretraining_tp) + ] + query_states = torch.cat(query_states, dim=-1) + + key_states = [ + F.linear(hidden_states, key_slices[i]) + for i in range(self.config.pretraining_tp) + ] + key_states = torch.cat(key_states, dim=-1) + + value_states = [ + F.linear(hidden_states, value_slices[i]) + for i in range(self.config.pretraining_tp) + ] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + ### Shift Pos: query pos is min(cache_size, idx) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids) + ### + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + ### Shift Pos: key pos is the pos in cache + key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0) + key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids) + ### + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + 
attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split( + self.hidden_size // self.config.pretraining_tp, dim=2 + ) + o_proj_slices = self.o_proj.weight.split( + self.hidden_size // self.config.pretraining_tp, dim=1 + ) + attn_output = sum( + [ + F.linear(attn_output[i], o_proj_slices[i]) + for i in range(self.config.pretraining_tp) + ] + ) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def enable_llama_pos_shift_attention(model): + for name, module in reversed(model._modules.items()): + if len(list(module.children())) > 0: + enable_llama_pos_shift_attention( + module, + ) + + if isinstance(module, LlamaAttention): + model._modules[name].forward = types.MethodType( + llama_pos_shift_attention_forward, model._modules[name] + ) diff --git a/python/turbine_models/custom_models/llm_runner.py b/python/turbine_models/custom_models/llm_runner.py index f3e84acc8..7632d1e65 100644 --- a/python/turbine_models/custom_models/llm_runner.py +++ b/python/turbine_models/custom_models/llm_runner.py @@ -3,6 +3,10 @@ from transformers import AutoTokenizer from iree import runtime as ireert import torch +import time +from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import ( + enable_llama_pos_shift_attention, +) parser = argparse.ArgumentParser() @@ -38,6 +42,11 @@ default="local-task", help="local-sync, local-task, cuda, vulkan, rocm", ) +parser.add_argument( + "--streaming_llm", + action="store_true", + help="Use streaming LLM mode for longer context and low memory usage.", +) parser.add_argument( "--prompt", type=str, @@ -46,42 +55,151 @@ """, help="prompt for llm model", ) +parser.add_argument( + "--chat_mode", + action="store_true", + help="Runs an interactive CLI chat mode.", +) +parser.add_argument( + "--chat_sys_prompt", + type=str, + default="""[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n <>\n\n +""", + help="System prompt used for interactive chat mode.", +) + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "", "" +DEFAULT_CHAT_SYS_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\n <>\n\n +""" + + +def append_user_prompt(history, input_prompt): + user_prompt = f"{B_INST} {input_prompt} {E_INST}" + history += user_prompt + return history + + +def append_bot_prompt(history, input_prompt): + user_prompt = f"{B_SYS} {input_prompt}{E_SYS} {E_SYS}" + history += user_prompt + return history + + +class SharkLLM(object): + def __init__(self, device, vmfb_path, external_weight_path, streaming_llm=False): + self.runner = vmfbRunner( + device=device, + vmfb_path=vmfb_path, + external_weight_path=external_weight_path, + ) + if streaming_llm: + self.model = self.runner.ctx.modules.streaming_state_update + else: + self.model = self.runner.ctx.modules.state_update + self.first_input = True + self.num_tokens = 0 + self.last_prompt = None + self.streaming_llm = streaming_llm + self.prev_token_len = 0 + + def format_out(self, results): + return torch.tensor(results.to_host()[0][0]) + + def evict_kvcache_space(self): + self.model["evict_kvcache_space"]() + + def generate(self, input_ids): + # TODO: Replace with args. + if self.streaming_llm and self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + turbine_results = [] + # Only need not seen token for init cache + # Because we have stored the res in KV-cache. + token_len = input_ids.shape[-1] + if self.streaming_llm: + token_slice = max(self.prev_token_len - 1, 0) + input_ids = input_ids[:, token_slice:] + inputs = [ireert.asdevicearray(self.runner.config.device, input_ids)] + if self.first_input or not self.streaming_llm: + s = time.time() + results = self.model["run_initialize"](*inputs) # example_input_id + e = time.time() + print( + f"num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}" + ) + token_len += 1 + self.first_input = False + else: + s = time.time() + results = self.model["run_cached_initialize"](*inputs) # example_input_id + e = time.time() + print( + f"Cached num_tokens: {token_len}, time_taken={e-s}, tok/second:{token_len/(e-s)}" + ) + token_len += 1 + s = time.time() + turbine_results.append(self.format_out(results)) + while self.format_out(results) != 2: + if self.streaming_llm and self.model["get_seq_step"]() > 600: + print("Evicting cache space!") + self.model["evict_kvcache_space"]() + results = self.model["run_forward"](results) + # uncomment to see tokens as they are emitted + # print(f"turbine: {tokenizer.decode(self.format_out(results))}") + turbine_results.append(self.format_out(results)) + e = time.time() + decoded_tokens = len(turbine_results) + print( + f"Decode num_tokens: {decoded_tokens}, time_taken={e-s}, tok/second:{decoded_tokens/(e-s)}" + ) + self.prev_token_len = token_len + decoded_tokens + return turbine_results def run_llm( - device, prompt, vmfb_path, hf_model_name, hf_auth_token, external_weight_path + device, + prompt, + vmfb_path, + hf_model_name, + hf_auth_token, + external_weight_path, + streaming_llm=False, + chat_mode=False, + chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT, ): - runner = vmfbRunner( - device=device, vmfb_path=vmfb_path, external_weight_path=external_weight_path - ) - tokenizer = AutoTokenizer.from_pretrained( hf_model_name, use_fast=False, token=hf_auth_token, ) - initial_input = tokenizer(prompt, return_tensors="pt") - example_input_id = initial_input.input_ids - inputs = [ireert.asdevicearray(runner.config.device, example_input_id)] - results = runner.ctx.modules.state_update["run_initialize"]( - *inputs - ) # 
example_input_id) - - def format_out(results): - return torch.tensor(results.to_host()[0][0]) - - turbine_results = [] - turbine_results.append(format_out(results)) - while format_out(results) != 2: - results = runner.ctx.modules.state_update["run_forward"](results) - # uncomment to see tokens as they are emitted - # print(f"turbine: {tokenizer.decode(format_out(results))}") - turbine_results.append(format_out(results)) - - return tokenizer.decode(turbine_results) + llm = SharkLLM( + device=device, + vmfb_path=vmfb_path, + external_weight_path=external_weight_path, + streaming_llm=streaming_llm, + ) + if not chat_mode: + initial_input = tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + turbine_results = llm.generate(example_input_id) + return tokenizer.decode(turbine_results) + prompt = chat_sys_prompt + while True: + user_prompt = input("User prompt: ") + prompt = append_user_prompt(prompt, user_prompt) + initial_input = tokenizer(prompt, return_tensors="pt") + example_input_id = initial_input.input_ids + result = llm.generate(example_input_id) + bot_response = tokenizer.decode(result, skip_special_tokens=True) + print(f"\nBOT: {bot_response}\n") + prompt = append_bot_prompt(prompt, bot_response) -def run_torch_llm(hf_model_name, hf_auth_token, prompt): +def run_torch_llm(hf_model_name, hf_auth_token, prompt, streaming_llm=False): from turbine_models.model_builder import HFTransformerBuilder from transformers import AutoModelForCausalLM @@ -93,6 +211,8 @@ def run_torch_llm(hf_model_name, hf_auth_token, prompt): auto_tokenizer=AutoTokenizer, ) model_builder.build_model() + if streaming_llm is True: + enable_llama_pos_shift_attention(model_builder.model) def get_token_from_logits(logits): return torch.argmax(logits[:, -1, :], dim=1) @@ -128,6 +248,9 @@ def get_token_from_logits(logits): args.hf_model_name, args.hf_auth_token, args.external_weight_path, + args.streaming_llm, + args.chat_mode, + args.chat_sys_prompt, ) print(turbine_output) if args.compare_vs_torch: diff --git a/python/turbine_models/custom_models/stateless_llama.py b/python/turbine_models/custom_models/stateless_llama.py index fcfb983f5..d51ad5041 100644 --- a/python/turbine_models/custom_models/stateless_llama.py +++ b/python/turbine_models/custom_models/stateless_llama.py @@ -8,6 +8,9 @@ from torch.utils import _pytree as pytree from shark_turbine.aot import * from iree.compiler.ir import Context +from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import ( + enable_llama_pos_shift_attention, +) from turbine_models.custom_models import remap_gguf import safetensors @@ -30,7 +33,9 @@ ) parser.add_argument("--quantization", type=str, default="unquantized") parser.add_argument("--external_weight_file", type=str, default="") -parser.add_argument("--vmfb_path", type=str, default="") +parser.add_argument( + "--vmfb_path", type=str, default=None, help="Path/name to store compiled vmfb." 
+) parser.add_argument( "--external_weights", type=str, @@ -40,7 +45,6 @@ parser.add_argument( "--precision", type=str, default="fp16", help="dtype of model [f16, f32]" ) - parser.add_argument( "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm" ) @@ -52,6 +56,11 @@ help="Specify vulkan target triple or rocm/cuda target device.", ) parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") +parser.add_argument( + "--streaming_llm", + action="store_true", + help="Compile LLM with StreamingLLM optimizations", +) # TODO (Dan): replace this with a file once I figure out paths on windows exe json_schema = """ @@ -62,6 +71,8 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim): all_pkv_tensors = [] for i in range(heads * 2): + # Numpy semantic: sliced = global_pkv[i, 0, 0:seq_step, 0:heads, 0:hidden_dim] + # Generates tensor<1 x 1 x seq_step x heads x hidden_dim> sliced = IREE.tensor_slice( global_pkv, i, 0, (0, seq_step), (0, heads), (0, hidden_dim) ) # sequence context dim @@ -83,6 +94,8 @@ def export_transformer_model( device=None, target_triple=None, vulkan_max_allocation=None, + streaming_llm=False, + vmfb_path=None, ): state_schema = pytree.treespec_loads(json_schema) @@ -91,6 +104,8 @@ def export_transformer_model( torch_dtype=torch.float, token=hf_auth_token, ) + if streaming_llm: + enable_llama_pos_shift_attention(mod) dtype = torch.float32 if precision == "f16": mod = mod.half() @@ -200,8 +215,84 @@ def forward(token0: torch.Tensor, *state0_flat): token1 = token1[None, :] return token1, *state1_flat + class StreamingStateUpdateModule(StateUpdateModule): + def run_cached_initialize( + self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64) + ): + state_arg = slice_up_to_step( + self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM + ) + forw_const = ( + [x.dynamic_dim(1) < MAX_STEP_SEQ] + + [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + + [ + x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1)) + for x in state_arg[1:] + ] + + [x.dynamic_dim(1) < MAX_STEP_SEQ for x in state_arg[1:]] + ) + token, *state = self.cached_initialize( + x, *state_arg, constraints=forw_const + ) + len_of_new_tokens = IREE.tensor_dim( + state[0], 1 + ) # ? dimension of arbitrarily 0th kv tensor + for i in range(HEADS * 2): + slice_of_state = IREE.tensor_reshape( + state[i], 1, 1, len_of_new_tokens, HEADS, HIDDEN_DIM + ) + self.global_state = IREE.tensor_update( + self.global_state, slice_of_state, i, 0, self.global_seq_step, 0, 0 + ) + self.global_seq_step = self.global_seq_step + len_of_new_tokens + return token + + @jittable + def cached_initialize(input_ids, *state0_flat): + # Unpad the states. + cur_token_len = state0_flat[0].size(1) + state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat] + state0 = pytree.tree_unflatten(state0_flat, state_schema) + result = mod.forward(input_ids, past_key_values=state0) + state1_flat, _ = pytree.tree_flatten(result.past_key_values) + state1_flat = [ + torch.transpose(x[:, :, cur_token_len:, :], 1, 2) for x in state1_flat + ] + token1 = torch.argmax(result.logits[:, -1, :], dim=1) + token1 = token1[None, :] + return token1, *state1_flat + + # Streaming-LLM KVCache evict algorithm: + # slice1 = KVCache[0 : sink] + # slice2 = KVCache[seq_len - window_size : seq_len] + # KVCache = torch.cat([slice1, slice2]) + # TODO: Add move to handle overlap of data. + def evict_kvcache_space(self): + # TODO: Replace hardcoded with global variable. 
+ sink_size = 4 + window_size = 252 + most_recent_window = self.global_seq_step + (-window_size) + for i in range(HEADS * 2): + update_window_state = IREE.tensor_slice( + self.global_state, + i, + 0, + (most_recent_window, window_size), + (0, HEADS), + (0, HIDDEN_DIM), + ) # sequence context dim + self.global_state = IREE.tensor_update( + self.global_state, update_window_state, i, 0, sink_size, 0, 0 + ) + self.global_seq_step.set(window_size + sink_size) + return self.global_seq_step + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" - inst = StateUpdateModule(context=Context(), import_to=import_to) + if streaming_llm: + print("Compiling with Streaming LLM") + inst = StreamingStateUpdateModule(context=Context(), import_to=import_to) + else: + inst = StateUpdateModule(context=Context(), import_to=import_to) # TODO: Integrate with external parameters to actually be able to run # TODO: Make more generalizable to be able to quantize with all compile_to options if quantization == "int4" and not compile_to == "linalg": @@ -266,7 +357,9 @@ def forward(token0: torch.Tensor, *state0_flat): target_backends=[device], extra_args=flags, ) - with open(f"{safe_name}.vmfb", "wb+") as f: + if vmfb_path is None: + vmfb_path = f"{safe_name}.vmfb" + with open(vmfb_path, "wb+") as f: f.write(flatbuffer_blob) print("saved to ", safe_name + ".vmfb") return module_str, tokenizer @@ -285,6 +378,8 @@ def forward(token0: torch.Tensor, *state0_flat): args.device, args.iree_target_triple, args.vulkan_max_allocation, + args.streaming_llm, + args.vmfb_path, ) safe_name = args.hf_model_name.split("/")[-1].strip() safe_name = re.sub("-", "_", safe_name) diff --git a/python/turbine_models/tests/stateless_llama_test.py b/python/turbine_models/tests/stateless_llama_test.py index 081c31b71..fc5bc9cd2 100644 --- a/python/turbine_models/tests/stateless_llama_test.py +++ b/python/turbine_models/tests/stateless_llama_test.py @@ -7,103 +7,147 @@ import logging import turbine_models.custom_models.stateless_llama as llama import os -import pytest - -from typing import Literal - - -import os -import sys -import re - -from typing import Tuple +import unittest +import difflib os.environ["TORCH_LOGS"] = "dynamic" -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -from torch.utils import _pytree as pytree from shark_turbine.aot import * -from iree.compiler.ir import Context -from iree import runtime as ireert - -from turbine_models.custom_models import remap_gguf -import safetensors - -from tqdm import tqdm - - -def test_vmfb_comparison(): - """ - Test that the vmfb model produces the same output as the torch model - - Precision can be 16 or 32, using 16 for speed and memory. +from turbine_models.custom_models import llm_runner + +from turbine_models.gen_external_params.gen_external_params import ( + gen_external_params, +) + +quantization = "unquantized" +precision = "f32" +gen_external_params( + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + quantization=quantization, + hf_auth_token=None, + precision=precision, +) +DEFAULT_PROMPT = """[INST] <> +Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST] +""" - For VMFB, quantization can be int4 or None, but right now only using none for compatibility with torch. 
- """ - quantization = "unquantized" - precision = "f32" - llama.export_transformer_model( - hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - hf_auth_token=None, - compile_to="vmfb", - external_weights="safetensors", - # external_weight_file="Llama-2-7b-chat-hf-function-calling-v2_f16_int4.safetensors", Do not export weights because this doesn't get quantized - quantization=quantization, - precision=precision, - device="llvm-cpu", - target_triple="host", +def check_output_string(reference, output): + # Calculate and print diff + diff = difflib.unified_diff( + reference.splitlines(keepends=True), + output.splitlines(keepends=True), + fromfile="reference", + tofile="output", + lineterm="", ) + assert reference == output, "".join(diff) - from turbine_models.gen_external_params.gen_external_params import ( - gen_external_params, - ) - gen_external_params( - hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", - quantization=quantization, - hf_auth_token=None, - precision=precision, - ) +class StatelessLlamaChecks(unittest.TestCase): + def test_vmfb_comparison(self): + """ + Test that the vmfb model produces the same output as the torch model - DEFAULT_PROMPT = """[INST] <> -Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <> hi what are you? [/INST] -""" + Precision can be 16 or 32, using 16 for speed and memory. - torch_str_cache_path = f"python/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt" - # if cached, just read - if os.path.exists(torch_str_cache_path): - with open(torch_str_cache_path, "r") as f: - torch_str = f.read() - else: - from turbine_models.custom_models import llm_runner + For VMFB, quantization can be int4 or None, but right now only using none for compatibility with torch. + """ - torch_str = llm_runner.run_torch_llm( - "Trelis/Llama-2-7b-chat-hf-function-calling-v2", None, DEFAULT_PROMPT + llama.export_transformer_model( + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + compile_to="vmfb", + external_weights="safetensors", + # external_weight_file="Llama-2-7b-chat-hf-function-calling-v2_f16_int4.safetensors", Do not export weights because this doesn't get quantized + quantization=quantization, + precision=precision, + device="llvm-cpu", + target_triple="host", ) - with open(torch_str_cache_path, "w") as f: - f.write(torch_str) + torch_str_cache_path = f"python/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt" + # if cached, just read + if os.path.exists(torch_str_cache_path): + with open(torch_str_cache_path, "r") as f: + torch_str = f.read() + else: + torch_str = llm_runner.run_torch_llm( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", None, DEFAULT_PROMPT + ) + + with open(torch_str_cache_path, "w") as f: + f.write(torch_str) + + turbine_str = llm_runner.run_llm( + "local-task", + DEFAULT_PROMPT, + "Llama_2_7b_chat_hf_function_calling_v2.vmfb", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + None, + f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", + ) + check_output_string(torch_str, turbine_str) + + def test_streaming_vmfb_comparison(self): + """ + Similar test to above but for streaming-LLM. 
+ """ + llama.export_transformer_model( + hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + compile_to="vmfb", + external_weights="safetensors", + # external_weight_file="Llama-2-7b-chat-hf-function-calling-v2_f16_int4.safetensors", Do not export weights because this doesn't get quantized + quantization=quantization, + precision=precision, + device="llvm-cpu", + target_triple="host", + streaming_llm=True, + vmfb_path="streaming_llama.vmfb", + ) - from turbine_models.custom_models import llm_runner + torch_str_cache_path = f"python/turbine_models/tests/vmfb_comparison_cached_torch_output_{precision}_{quantization}.txt" + # if cached, just read + if os.path.exists(torch_str_cache_path): + with open(torch_str_cache_path, "r") as f: + torch_str = f.read() + else: + torch_str = llm_runner.run_torch_llm( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + None, + DEFAULT_PROMPT, + streaming_llm=True, + ) + + with open(torch_str_cache_path, "w") as f: + f.write(torch_str) + + turbine_str = llm_runner.run_llm( + "local-task", + DEFAULT_PROMPT, + "streaming_llama.vmfb", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + None, + f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", + streaming_llm=True, + ) + check_output_string(torch_str, turbine_str) - turbine_str = llm_runner.run_llm( - "local-task", - DEFAULT_PROMPT, - "Llama_2_7b_chat_hf_function_calling_v2.vmfb", - "Trelis/Llama-2-7b-chat-hf-function-calling-v2", - None, - f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors", - ) + def test_rerotated_torch_comparison(self): + torch_str = llm_runner.run_torch_llm( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + None, + DEFAULT_PROMPT, + ) + rotated_torch_str = llm_runner.run_torch_llm( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + None, + DEFAULT_PROMPT, + streaming_llm=True, + ) + check_output_string(torch_str, rotated_torch_str) - import difflib - # Calculate and print diff - diff = difflib.unified_diff( - torch_str.splitlines(keepends=True), - turbine_str.splitlines(keepends=True), - fromfile="torch_str", - tofile="turbine_str", - lineterm="", - ) - assert torch_str == turbine_str, "".join(diff) +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index fa343b36d..11618155f 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -194,6 +194,67 @@ def read(self): self.assertIn("@_state_i64.global {noinline} = 0 : i64", module_str) self.assertIn("@_state_bool.global {noinline} = false", module_str) + def testInheritExportScalars(self): + class BaseState(CompiledModule): + state_index = export_global(AbstractIndex, mutable=True) + state_f32 = export_global(AbstractF32, mutable=True) + + def read(self): + return (self.state_index, self.state_f32) + + class DerivedState(BaseState): + pass + + inst = DerivedState(context=Context()) + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("@_state_index.global {noinline} = 0 : index", module_str) + self.assertIn("@_state_f32.global {noinline} = 0.000000e+00 : f32", module_str) + self.assertIn( + "return %_state_index.global, %_state_f32.global : index, f32", module_str + ) + + def testInheritOverrideBase(self): + class BaseState(CompiledModule): + state_index = export_global(AbstractIndex, mutable=True) + state_f32 = export_global(AbstractF32, mutable=True) + + def read(self): + return 
(self.state_index, self.state_f32) + + class DerivedState(BaseState): + def read(self): + return self.state_index + + inst = DerivedState(context=Context()) + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("@_state_index.global {noinline} = 0 : index", module_str) + self.assertNotIn( + "@_state_f32.global {noinline} = 0.000000e+00 : f32", module_str + ) + self.assertIn("return %_state_index.global : index", module_str) + + def testInheritExportModules(self): + m = SimpleParams() + + class BaseModule(CompiledModule): + params = export_parameters(m, mutable=True) + + def update_params(me, updates=abstractify(params)): + self.assertIn("classifier.weight", updates) + self.assertIn("classifier.bias", updates) + me.params = updates + + class DerivedModule(BaseModule): + pass + + inst = DerivedModule(context=Context()) + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str) + self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str) + def testUpdateGlobalStateTree(self): state_example = { "data": torch.randn(3, 11), diff --git a/tests/aot/iree_procedural_test.py b/tests/aot/iree_procedural_test.py index d541301be..9f4799210 100644 --- a/tests/aot/iree_procedural_test.py +++ b/tests/aot/iree_procedural_test.py @@ -251,6 +251,22 @@ def foobar(self, a=AbstractI32): self.assertIn("%cst = arith.constant 3.230000e+00 : f32", module_str) self.assertIn("arith.addf %0, %cst : f32", module_str) + def testSetScalarState(self): + class ArithModule(CompiledModule): + state_index = export_global(AbstractIndex, mutable=True) + state_f32 = export_global(AbstractF32, mutable=True) + + def foobar(self): + self.state_index.set(5) + self.state_f32.set(5.5) + + inst = ArithModule(context=Context(), import_to=None) + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn("util.global.store %c5, @_state_index.global : index", module_str) + self.assertIn("%cst = arith.constant 5.500000e+00 : f32", module_str) + self.assertIn("util.global.store %cst, @_state_f32.global", module_str) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)
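
Taken together, the flow that `stateless_llama_test.py` above exercises can also be driven directly from Python. A minimal end-to-end sketch using only calls that appear in this patch (the prompt string is illustrative, and the weight and vmfb files are written to the current directory):

```python
import turbine_models.custom_models.stateless_llama as llama
from turbine_models.custom_models import llm_runner
from turbine_models.gen_external_params.gen_external_params import gen_external_params

hf_model = "Trelis/Llama-2-7b-chat-hf-function-calling-v2"
quantization, precision = "unquantized", "f32"

# 1. Save the model weights to a .safetensors file (quantizing if requested).
gen_external_params(
    hf_model_name=hf_model,
    quantization=quantization,
    hf_auth_token=None,
    precision=precision,
)

# 2. Export the model and compile it (with StreamingLLM enabled) to a vmfb.
llama.export_transformer_model(
    hf_model_name=hf_model,
    hf_auth_token=None,
    compile_to="vmfb",
    external_weights="safetensors",
    quantization=quantization,
    precision=precision,
    device="llvm-cpu",
    target_triple="host",
    streaming_llm=True,
    vmfb_path="streaming_llama.vmfb",
)

# 3. Run the compiled module through the IREE runtime and decode the output.
output = llm_runner.run_llm(
    "local-task",
    "hi what are you?",
    "streaming_llama.vmfb",
    hf_model,
    None,
    f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
    streaming_llm=True,
)
print(output)
```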