Skip to content

Commit

Permalink
[Inference]Adapt to baichuan2 13B (hpcaitech#5614)
Browse files Browse the repository at this point in the history
* adapt to baichuan2 13B

* adapt to baichuan2 13B

* change BAICHUAN_MODEL_NAME_OR_PATH

* fix test_decoding_attn.py

* Modifications based on review comments.

* change BAICHUAN_MODEL_NAME_OR_PATH

* mv attn mask processes to test flash decoding

* mv get_alibi_slopes baichuan modeling

* fix bugs in test_baichuan.py
  • Loading branch information
isky-cd authored Apr 25, 2024
1 parent f342a93 commit 3c91e3f
Show file tree
Hide file tree
Showing 10 changed files with 786 additions and 134 deletions.
1 change: 1 addition & 0 deletions colossalai/inference/flash_decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ def initialize(
self._mid_output_lse = torch.empty(
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
)

self._tensors_initialized = True
9 changes: 8 additions & 1 deletion colossalai/inference/kv_cache/kvcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
self.head_num = get_model_config_attr(model_config, "num_attention_heads")
self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num

if hasattr(config, "num_key_value_heads"):
self.kv_head_num = getattr(config, "num_key_value_heads")
elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
else:
self.kv_head_num = self.head_num

assert (
self.kv_head_num % self.tp_size == 0
), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
Expand Down
208 changes: 186 additions & 22 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,83 @@
# This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
flash_decoding_attention,
rms_layernorm,
rotary_embedding,
)
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)

try:
from flash_attn import flash_attn_varlen_func

use_flash_attn2 = True
except ImportError:
use_flash_attn2 = False
logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")

inference_ops = InferenceOpsLoader().load()

logger = get_dist_logger(__name__)


# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes


def baichuan_rmsnorm_forward(
self,
hidden_states: torch.Tensor,
norm_output: torch.Tensor,
residual: torch.Tensor = None,
use_cuda_kernel: bool = True,
):
# Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b.
if hasattr(self, "variance_epsilon"):
eps = self.variance_epsilon
elif hasattr(self, "epsilon"):
eps = self.epsilon
else:
TypeError(
"Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
)

if use_cuda_kernel:
if residual is not None:
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
return hidden_states, residual

if norm_output is None:
norm_output = torch.empty_like(hidden_states)
inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
return norm_output, hidden_states
else:
return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)


class NopadBaichuanAttention(nn.Module):
def __init__(
self,
Expand All @@ -39,9 +103,11 @@ def __init__(
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads

# Used to adapt llama_base_attn_forward
self.num_key_value_heads = self.num_heads
self.alibi_slopes = None
self.use_alibi_attn = False
if self.hidden_size == 5120:
self.use_alibi_attn = True
self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)

qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
Expand Down Expand Up @@ -112,26 +178,124 @@ def forward(
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
"""

return NopadLlamaAttention.forward(
self,
hidden_states=hidden_states,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
sequence_lengths=sequence_lengths,
cos_sin=cos_sin,
fd_inter_tensor=fd_inter_tensor,
is_prompts=is_prompts,
is_verifier=is_verifier,
tokens_to_verify=tokens_to_verify,
kv_seq_len=kv_seq_len,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
cu_seqlens=cu_seqlens,
high_precision=high_precision,
token_nums = hidden_states.size(0)
# fused qkv
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
)

block_size = k_cache.size(-2)

if is_prompts:
if (
not is_verifier
and use_cuda_kernel
and query_states.dtype != torch.float32
and use_flash_attn2
and not self.use_alibi_attn
):
# flash attn 2 currently only supports FP16/BF16.
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
inference_ops.context_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
)

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=kv_seq_len,
max_seqlen_k=kv_seq_len,
dropout_p=0.0,
softmax_scale=sm_scale,
causal=True,
)
attn_output = attn_output.view(token_nums, -1)
else:
if not self.use_alibi_attn:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1

if use_cuda_kernel:
if not self.use_alibi_attn:
inference_ops.rotary_embedding_and_cache_copy(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
sequence_lengths,
block_tables,
high_precision,
)
else:
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
else:
if not is_verifier and not self.use_alibi_attn:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
else:
if not self.use_alibi_attn:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
copy_k_to_blocked_cache(
key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)
copy_k_to_blocked_cache(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)

attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)

return attn_output


# NOTE This will cause difference as out length increases.
class NopadBaichuanMLP(nn.Module):
Expand Down
47 changes: 26 additions & 21 deletions colossalai/inference/modeling/policy/nopadding_baichuan.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import torch.nn as nn
from torch.nn import Parameter

from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP
from colossalai.inference.modeling.models.nopadding_baichuan import (
NopadBaichuanAttention,
NopadBaichuanMLP,
baichuan_rmsnorm_forward,
)
from colossalai.inference.modeling.models.nopadding_llama import (
llama_causal_lm_forward,
llama_decoder_layer_forward,
llama_model_forward,
llama_rmsnorm_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
Expand All @@ -21,38 +24,40 @@ def module_policy(self):
policy = super().module_policy()

decoder_attribute_replacement = {
"lm_head.weight": Parameter(
nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False
),
"lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False),
}
policy["BaichuanForCausalLM"] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)

policy["DecoderLayer"] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
)
# used for relpacing Baichuan 7B/13B decoder layer
for layer_name in ["DecoderLayer", "BaichuanLayer"]:
policy[layer_name] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="mlp",
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn",
target_module=NopadBaichuanAttention,
),
]
)

self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name
)

self.append_or_create_method_replacement(
description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM"
)
self.append_or_create_method_replacement(
description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel"
)

self.append_or_create_method_replacement(
description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer"
)
self.append_or_create_method_replacement(
description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm"
)

return policy
Expand Down
Loading

0 comments on commit 3c91e3f

Please sign in to comment.