Llama adapter (#983)
Signed-off-by: Wang, Yi A <[email protected]>
Co-authored-by: regisss <[email protected]>
sywangyi and regisss authored Jul 15, 2024
1 parent b05584a commit 609e450
Showing 10 changed files with 289 additions and 37 deletions.
2 changes: 2 additions & 0 deletions Makefile
@@ -65,6 +65,8 @@ fast_test_videomae:
# Run single-card non-regression tests
slow_tests_1x: test_installs
python -m pytest tests/test_examples.py -v -s -k "single_card"
python -m pip install peft==0.10.0
python -m pytest tests/test_peft_inference.py
python -m pytest tests/test_pipeline.py

# Run multi-card non-regression tests
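
For reference, the new PEFT tests are wired into the existing single-card target, so they can be run either through `make slow_tests_1x` or directly; a minimal sketch, assuming a Gaudi machine with `optimum-habana` installed from this tree:

```bash
# Pin the PEFT version the new tests expect, then run the added test module
# (these are the same commands the updated slow_tests_1x target executes).
python -m pip install peft==0.10.0
python -m pytest tests/test_peft_inference.py -v
```
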
4 changes: 2 additions & 2 deletions examples/language-modeling/README.md
@@ -362,7 +362,7 @@ python run_clm.py \

## PEFT

### LORA/ADALORA/IA3
### LORA/ADALORA/IA3/LLAMA_ADAPTER

To run LoRA finetuning, you can use `run_lora_clm.py`.
Here are single-/multi-device command examples for Llama1-7B, Falcon-40B, Llama2-70B, Llama3-8B and Llama3-70B.
@@ -691,7 +691,7 @@ DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 ..
--validation_split_percentage 5 \
--deepspeed ds_falcon_180b_z3.json
```
Default `peft_type` is `lora`, you could enable adalora or ia3 using `--peft_type adalora` or `--peft_type ia3`.
The default `peft_type` is `lora`; you can enable AdaLoRA or IA3 with `--peft_type adalora` or `--peft_type ia3`, or enable LLaMA-Adapter for Llama models with `--peft_type llama-adapter`.
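
For illustration, a multi-card LLaMA-Adapter run could look like the command below. This is a sketch assembled from the flags used in the other PEFT examples of this README and the adapter hyperparameters used in the test baselines of this commit; the model, dataset, output path and hyperparameters are placeholders, not a validated recipe:

```bash
python ../gaudi_spawn.py --world_size 8 --use_mpi run_lora_clm.py \
    --model_name_or_path huggyllama/llama-7b \
    --dataset_name tatsu-lab/alpaca \
    --bf16 True \
    --output_dir ./model_llama_adapter \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 3e-4 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "constant" \
    --max_grad_norm 0.3 \
    --logging_steps 1 \
    --dataset_concatenation \
    --max_seq_length 512 \
    --use_habana \
    --use_lazy_mode \
    --throughput_warmup_steps 3 \
    --peft_type llama-adapter \
    --adapter_layers 2 \
    --adapter_len 4
```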

### Prompt/Prefix/P-tuning

25 changes: 23 additions & 2 deletions examples/language-modeling/run_lora_clm.py
@@ -30,7 +30,7 @@
import torch
import transformers
from datasets import load_dataset
from peft import AdaLoraConfig, IA3Config, LoraConfig, TaskType, get_peft_model, tuners
from peft import AdaLoraConfig, AdaptionPromptConfig, IA3Config, LoraConfig, TaskType, get_peft_model, tuners
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import (
AutoConfig,
@@ -338,7 +338,7 @@ class FinetuneArguments:
default="lora",
metadata={
"help": ("The PEFT type to use."),
"choices": ["lora", "ia3", "adalora"],
"choices": ["lora", "ia3", "adalora", "llama-adapter"],
},
)
ia3_target_modules: List[str] = field(
@@ -349,6 +349,14 @@ class FinetuneArguments:
default_factory=lambda: None,
metadata={"help": "Target feedforward modules for the IA3 method."},
)
adapter_layers: int = field(
default=30,
metadata={"help": "Number of adapter layers (from the top) in llama-adapter"},
)
adapter_len: int = field(
default=10,
metadata={"help": "Number of adapter tokens to insert in llama-adapter"},
)


PROMPT_DICT = {
@@ -785,6 +793,19 @@ def compute_metrics(eval_preds):
feedforward_modules=finetune_args.feedforward_modules,
task_type=TaskType.CAUSAL_LM,
)
elif finetune_args.peft_type == "llama-adapter":
peft_config = AdaptionPromptConfig(
adapter_layers=finetune_args.adapter_layers,
adapter_len=finetune_args.adapter_len,
task_type=TaskType.CAUSAL_LM,
)
from optimum.habana.peft.layer import (
GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)

tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
lora_model = get_peft_model(model, peft_config)
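
Outside the training script, the new branch can be reproduced in a few lines. The following is a minimal sketch (model name and adapter hyperparameters are illustrative), assuming a Gaudi environment where the Llama attention has been adapted so that `pre_attn_forward` exists:

```python
import torch
from peft import AdaptionPromptConfig, TaskType, get_peft_model, tuners
from transformers import AutoModelForCausalLM

from optimum.habana.peft.layer import (
    GaudiAdaptedAttention_getattr,
    GaudiAdaptedAttentionPreAttnForward,
)
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Swap in the Gaudi-optimized Llama implementation (exposes pre_attn_forward).
adapt_transformers_to_gaudi()

# Route PEFT's AdaptedAttention through the Gaudi-friendly hooks, as done above.
tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.bfloat16)
peft_config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the adaption prompts and gates should be trainable
```
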
13 changes: 13 additions & 0 deletions examples/text-generation/utils.py
@@ -238,6 +238,8 @@ def setup_model(args, model_dtype, model_kwargs, logger):
assistant_model = wrap_in_hpu_graph(assistant_model)
if _is_peft_model(model):
model.base_model = wrap_in_hpu_graph(model.base_model)
if model.peft_type == "ADAPTION_PROMPT":
model.base_model.model = wrap_in_hpu_graph(model.base_model.model)

if args.torch_compile and model.config.model_type == "llama":
model = get_torch_compiled_model(model)
@@ -372,6 +374,17 @@ def peft_model(args, model_dtype, logger, **model_kwargs):

model.__class__.generate = gaudi_generate
model.__class__.prepare_inputs_for_generation = gaudi_prepare_inputs_for_generation
if model.peft_type == "ADAPTION_PROMPT":
from peft import tuners

from optimum.habana.peft.layer import (
GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)

tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr

return model


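On the inference side, the `peft_model()` path above is reached through the text-generation example's `--peft_model` option. A hedged sketch, with the checkpoint path and prompt as placeholders:

```bash
python run_generation.py \
    --model_name_or_path huggyllama/llama-7b \
    --peft_model /path/to/llama_adapter_checkpoint \
    --bf16 \
    --use_hpu_graphs \
    --use_kv_cache \
    --max_new_tokens 128 \
    --prompt "Here is my prompt"
```
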
6 changes: 5 additions & 1 deletion optimum/habana/peft/__init__.py
@@ -1,2 +1,6 @@
from .layer import GaudiAdaloraLayerSVDLinearForward
from .layer import (
GaudiAdaloraLayerSVDLinearForward,
GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)
from .peft_model import gaudi_generate, gaudi_prepare_inputs_for_generation
144 changes: 144 additions & 0 deletions optimum/habana/peft/layer.py
@@ -1,6 +1,11 @@
import inspect
import math
from typing import Any

import torch
import torch.nn.functional as F
from peft.tuners.adaption_prompt.config import TRANSFORMERS_MODEL_CONFIG
from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half


def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
@@ -31,3 +36,142 @@ def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwarg
result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * (scaling / ranknum)

return result


def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
"""
Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
The only differences are:
    - add reuse_cache support
    - add past key value list support
"""
hidden_states = kwargs.get("hidden_states")
position_ids = kwargs.get("position_ids")
past_key_value = kwargs.get("past_key_value")
bsz, q_len, _ = hidden_states.size()
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

factor = model.k_proj.in_features // model.k_proj.out_features
value_states = (
model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
)

seq_len = q_len

if past_key_value is not None:
if kwargs.get("reuse_cache", False):
seq_len += past_key_value[0][-2]
elif isinstance(past_key_value, tuple) or isinstance(past_key_value, list):
# for transformers <= 4.35
seq_len += past_key_value[0].shape[-2]
else:
# since transformers 4.36, this is a DynamicCache instance
seq_len += past_key_value.get_seq_length(model.layer_idx)

    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
# TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

past_seen_tokens = 0
if position_ids is None:
# Compute position_ids, since they are required for transformers > 4.37.2
if past_key_value is None:
new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
else:
past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
position_ids = new_cache_positions.unsqueeze(0)

rotary_emb_kwargs = {"position_ids": position_ids}
# The `seq_len` argument has been officially removed in transformers >= 4.39.0
if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens

cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)

# For batched inference unsqueeze it on the correct dim
# since: https://github.com/huggingface/transformers/pull/29109
if len(cos.shape) == 3:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)

return (query_states * cos) + (llama_rotate_half(query_states) * sin)


def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
"""
Copied from AdaptedAttention.forward: https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/layer.py#L57
The only differences are:
- replace self.model() with self.model.pre_attn_forward()
"""
if kwargs.get("output_attention", False):
raise NotImplementedError("output_attention is not currently supported.")

output, _, past_key_value = self.model.pre_attn_forward(*args, **kwargs)
bsz = output.shape[0]
q_len = output.shape[1]
embed_dim = output.shape[2]
k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
factor = (
self.model.k_proj.in_features // self.model.k_proj.out_features
) # Mistral has different input and output dimension for k_proj and v_proj layers

if k_proj_layer == v_proj_layer:
_, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
else:
key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
value = getattr(self.model, v_proj_layer)(self.adaption_prompt)

# (bsz, num_key_value_heads, adapter_len, head_dim)
adapter_k = (
key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
adapter_v = (
value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
.repeat(bsz, 1, 1, 1)
.transpose(1, 2)
)
# Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181
# (bsz, num_heads, adapter_len, head_dim)
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
# Recompute query states.
# (bsz, num_heads, q_len, head_dim)
query_states = compute_query_states(model=self.model, **kwargs)

previous_dtype = query_states.dtype

# (bsz, num_heads, q_len, adapter_len)
scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(self.model.head_dim)
# Upcast attention to fp32
# (bsz, num_heads, q_len, adapter_len)
scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
# (bsz, q_len, num_heads * head_dim)
adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)

# (bsz, q_len, hidden_size)
if o_proj_layer is not None:
adapter_output = getattr(self.model, o_proj_layer)(adapter_output)

# Add adaption prompt output to original output.
output = output + adapter_output

# Restore original dtype.
output = output.to(previous_dtype)
return output, None, past_key_value


def GaudiAdaptedAttention_getattr(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super(self.__class__, self).__getattr__(name)
except AttributeError:
# This is necessary as e.g. causal models have various methods that we
# don't want to re-implement here.
return getattr(self.model, name)
28 changes: 14 additions & 14 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -656,20 +656,20 @@ def pre_attn(
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
token_idx,
attn_softmax_bf16,
reuse_cache,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
38 changes: 38 additions & 0 deletions tests/baselines/llama_7b.json
@@ -169,6 +169,44 @@
}
}
},
"llama-adapter": {
"num_train_epochs": 3,
"eval_batch_size": 4,
"distribution": {
"multi_card": {
"learning_rate": 3e-4,
"train_batch_size": 8,
"perplexity": 5.575,
"train_runtime": 131.7,
"train_samples_per_second": 294,
"extra_arguments": [
"--bf16",
"--gradient_accumulation_steps 2",
"--evaluation_strategy no",
"--save_strategy no",
"--warmup_ratio 0.03",
"--lr_scheduler_type constant",
"--max_grad_norm 0.3",
"--logging_steps 1",
"--use_hpu_graphs_for_inference",
"--lora_rank 8",
"--lora_alpha 16",
"--lora_dropout 0.05",
"--lora_target_modules q_proj v_proj",
"--dataset_concatenation",
"--max_seq_length 512",
"--low_cpu_mem_usage True",
"--adam_epsilon 1e-08",
"--ddp_bucket_cap_mb 50",
"--validation_split_percentage 10",
"--attn_softmax_bf16",
"--adapter_layers 2",
"--adapter_len 4",
"--peft_type llama-adapter"
]
}
}
},
"trl-sft": {
"num_train_epochs": 1,
"eval_batch_size": 1,
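
These baselines are consumed by the example test suite; assuming the usual invocation pattern from the Makefile, something along these lines should exercise the new entry (the exact `-k` filter depends on how the tests are parametrized):

```bash
# Pin the PEFT version, then run the multi-card example tests; the -k
# expression can be narrowed further (e.g. adding "llama") if only this
# baseline is of interest.
python -m pip install peft==0.10.0
python -m pytest tests/test_examples.py -v -s -k "multi_card"
```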