[Stateless Llama] StreamingLLM + Add KVCache for prefill stage + interactive chat mode in llm_runner. (#299)

This PR introduces StreamingLLM plus a KV cache at the initialization/prefill stage, which lets us generate an unbounded number of tokens with controlled memory growth. This PR also introduces:
1. Set capabilities for GlobalScalars.
2. Inheritance of exports/globals for CompiledModule subclasses.
3. READMEs for llm_runner and stateless_llama.
4. e2e test refactoring.
1 parent 18e8a41, commit 432fa0d. Showing 14 changed files with 736 additions and 113 deletions.
@@ -21,4 +21,10 @@ _python_build/
dist/
wheelhouse
*.egg-info
*.whl

# Model artifacts
*.pt
*.safetensors
*.gguf
*.vmfb
@@ -0,0 +1,39 @@
# Instructions

Clone and install SHARK-Turbine:
```
git clone https://github.com/nod-ai/SHARK-Turbine.git
cd SHARK-Turbine
python -m venv turbine_venv && source turbine_venv/bin/activate
pip install --upgrade -r requirements.txt
pip install --upgrade -e .[torch-cpu-nightly,testing]
pip install --upgrade -r turbine-models-requirements.txt
```

## Compiling LLMs
Note: make sure to replace "your_token" with your actual hf_auth_token in all of the commands below.

First, generate the quantized weight file with:
```
python python/turbine_models/gen_external_params/gen_external_params.py --hf_auth_token=your_token
```
The model weights will then be saved in the current directory as `Llama_2_7b_chat_hf_f16_int4.safetensors`.
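
If you want to sanity-check the generated weights, here is a minimal sketch using the `safetensors` package (not part of this repo; it only assumes the filename produced above):
```python
from safetensors import safe_open

# Print a few tensors from the quantized weight file to confirm it was written correctly.
with safe_open("Llama_2_7b_chat_hf_f16_int4.safetensors", framework="pt") as f:
    for name in list(f.keys())[:5]:
        tensor = f.get_tensor(name)
        print(name, tuple(tensor.shape), tensor.dtype)
```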

To compile llama to a vmfb:
```
python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16"
```
By default the vmfb will be saved as `Llama_2_7b_chat_hf.vmfb`.

## Running LLMs
There are two ways of running LLMs:

1) A single run with a predefined prompt, to validate correctness.
```
python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token
```
2) Interactive CLI chat mode (just add a `--chat_mode` flag; conceptually a read-generate-print loop, sketched after this list).
```
python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token --chat_mode
```
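
Chat mode is, at its core, an interactive loop around the model. A rough sketch of the idea (this is not the llm_runner.py implementation; `generate` is a hypothetical callable standing in for the compiled model):
```python
def chat_loop(generate):
    """Minimal interactive loop; `generate` maps a prompt string to a reply string."""
    print("Type 'exit' to quit.")
    while True:
        prompt = input("user: ")
        if prompt.strip().lower() in ("exit", "quit"):
            break
        print("assistant:", generate(prompt))
```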
python/turbine_models/custom_models/llm_optimizations/streaming_llm/README.md (32 changes: 32 additions & 0 deletions)
@@ -0,0 +1,32 @@
# StreamingLLM

StreamingLLM is based on the paper *"Efficient Streaming Language Models with Attention Sinks"* by Xiao et al. from the MIT Han Lab. Here are the original [[paper](http://arxiv.org/abs/2309.17453)] and [[code](https://github.com/mit-han-lab/streaming-llm)].

The modify_llama.py code is heavily inspired by the modify_llama.py in the original repo, but tweaked to work with ToM HuggingFace and to be compilable through Turbine.

The work introduces attention sinks: in short, attention over a small, fixed set of initial tokens combined with sliding-window attention over the most recent tokens. This is beneficial for two reasons (a sketch of the resulting cache-eviction policy follows the list):

1) It can generate over an effectively infinite context.
2) It keeps memory under a fixed threshold (controlled by window_length).
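
The following is a minimal, illustrative sketch of sink + sliding-window KV-cache eviction, not the code in this PR; `sink_size` and `window_length` are assumed parameters:
```python
from typing import Tuple
import torch

def evict_kv_cache(
    keys: torch.Tensor,        # [batch, heads, seq_len, head_dim]
    values: torch.Tensor,      # same shape as keys
    sink_size: int = 4,        # number of initial "attention sink" tokens to always keep
    window_length: int = 256,  # number of most recent tokens to keep
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Keep the first `sink_size` tokens plus the last `window_length` tokens."""
    seq_len = keys.shape[2]
    if seq_len <= sink_size + window_length:
        return keys, values  # nothing to evict yet
    keep = torch.cat(
        [
            torch.arange(0, sink_size),
            torch.arange(seq_len - window_length, seq_len),
        ]
    )
    return keys[:, :, keep, :], values[:, :, keep, :]
```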

## Compiling LLMs with StreamingLLM

Just add an extra `--streaming_llm` flag when you call stateless_llama to generate your vmfb. For example:
```
python python/turbine_models/custom_models/stateless_llama.py --compile_to=vmfb --hf_auth_token=your_token --external_weights="safetensors" --quantization="int4" --precision="f16" --streaming_llm
```

By default the vmfb will still be saved as `Llama_2_7b_chat_hf.vmfb`.

## Running LLMs with StreamingLLM

Similarly, just add an extra `--streaming_llm` flag when you call llm_runner.py. For example:
```
python python/turbine_models/custom_models/llm_runner.py --vmfb_path=/path/to/Llama_2_7b_chat_hf.vmfb --external_weight_path=Llama_2_7b_chat_hf_f16_int4.safetensors --device=vulkan --hf_auth_token=your_hf_token --chat_mode --streaming_llm=true
```

## Future Work
- [ ] Make the window size configurable from Python. Everything is in place, but we need to initialize with a default value, which will only be possible once `_create_initial_value` can take an initial value from a GlobalAttribute, somewhere [here](https://github.com/nod-ai/SHARK-Turbine/blob/18e8a4100b61adfd9425dd32f780dc5f90017813/python/shark_turbine/aot/support/ir_utils.py#L284-L316).
- [ ] Get flow.move to enable overlap of the sliding window and the source of the data. (Currently we need to evict only once the cache is at least 2x the window size.) For example, by default our StreamingLLM window_size is 256, so we evict at ~600 tokens (slightly more than 2x, for safety).
- [ ] Introduce re-rotation of RoPE, as seen [here](https://github.com/huggingface/transformers/blob/c2d283a64a7f33547952e3eb0fa6533fc375bcdd/src/transformers/cache_utils.py#L213-L218), to remove the invasive modification of the LlamaAttention module for StreamingLLM. (A sketch of the idea follows this list.)
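
For context, here is a hedged sketch of what RoPE re-rotation could look like, using the rotate_half/cos/sin conventions from transformers (and from modify_llama.py below). It is an illustration, not the HuggingFace implementation; `shift` is the assumed number of positions the cached keys move back after eviction, and it applies to the sliding-window part of the cache (sink tokens keep their positions):
```python
import torch
from transformers.models.llama.modeling_llama import rotate_half

def rerotate_keys(key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, shift: int) -> torch.Tensor:
    # Keys were rotary-embedded at their original positions. After eviction, each
    # cached key's position id shrinks by `shift`, so rotate it back by `shift`:
    # R(-shift) * k == cos(shift) * k - sin(shift) * rotate_half(k).
    cos_s = cos[shift]  # rotary table entry for position `shift`, shape [head_dim]
    sin_s = sin[shift]
    return key_states * cos_s - rotate_half(key_states) * sin_s
```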
python/turbine_models/custom_models/llm_optimizations/streaming_llm/modify_llama.py (171 changes: 171 additions & 0 deletions)
@@ -0,0 +1,171 @@
import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.utils.checkpoint

import torch.nn.functional as F

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
)
import types

__all__ = ["enable_llama_pos_shift_attention"]


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

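# NOTE: The forward below is a drop-in replacement for LlamaAttention.forward. The key
# difference from stock transformers is position handling: queries are rotated using the
# caller-provided `position_ids`, while keys are rotated using their positions *in the
# cache* (`key_position_ids` below). Keys/values are cached before rotary embedding is
# applied, so evicting cache entries does not require re-rotating them.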
def llama_pos_shift_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    ### Shift Pos: query pos is min(cache_size, idx)
    # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
    ###

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    ### Shift Pos: key pos is the pos in cache
    key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
    key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
    ###

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        self.head_dim
    )

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(
            self.hidden_size // self.config.pretraining_tp, dim=2
        )
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.config.pretraining_tp, dim=1
        )
        attn_output = sum(
            [
                F.linear(attn_output[i], o_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
        )
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def enable_llama_pos_shift_attention(model):
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            enable_llama_pos_shift_attention(
                module,
            )

        if isinstance(module, LlamaAttention):
            model._modules[name].forward = types.MethodType(
                llama_pos_shift_attention_forward, model._modules[name]
            )
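
For reference, a minimal usage sketch of `enable_llama_pos_shift_attention`. The import path is assumed from the file location in this PR; the checkpoint name is illustrative, and gated models will additionally need an HF auth token:
```python
import torch
from transformers import AutoModelForCausalLM

from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import (
    enable_llama_pos_shift_attention,
)

# Load a HuggingFace Llama checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float32
)

# Patch every LlamaAttention module in place so rotary positions are derived from the
# cache, which is what StreamingLLM's sliding-window KV cache needs.
enable_llama_pos_shift_attention(model)
```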