Add support for Baichuan2
Below is an example for baichuan-inc/Baichuan2-7B-Chat:
python3 run_generation.py \
--model_name_or_path baichuan-inc/Baichuan2-7B-Chat \
--bf16 --trim_logits --batch_size 1 \
--max_input_tokens 1024 --max_new_tokens 512 \
--use_kv_cache --use_hpu_graphs --use_flash_attention \
--reuse_cache \
--no-ignore_eos

Below is an example for baichuan-inc/Baichuan2-13B-Chat:
python3 run_generation.py \
--model_name_or_path baichuan-inc/Baichuan2-13B-Chat \
--bf16 --trim_logits --batch_size 1 \
--max_input_tokens 1024 --max_new_tokens 512 \
--use_kv_cache --use_hpu_graphs --bucket_size 256 \
--bucket_internal --reuse_cache \
--no-ignore_eos

Co-authored-by: Jianqian Zhou <[email protected]>
Co-authored-by: Wei Lin <[email protected]>
Signed-off-by: Haihao Xiang <[email protected]>
3 people committed Nov 12, 2024
1 parent 7ac1db1 commit 952eb48
Showing 14 changed files with 2,167 additions and 10 deletions.
8 changes: 5 additions & 3 deletions examples/language-modeling/run_clm.py
@@ -472,9 +472,11 @@ def main():

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
# We need to skip this test for Baichuan pretraining
if config.model_type not in ("baichuan",):
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))

# Preprocessing the datasets.
# First we tokenize all the texts.
3 changes: 2 additions & 1 deletion examples/text-generation/run_lm_eval.py
@@ -111,13 +111,14 @@ def __init__(self, tokenizer, model, args, options):
"gptj",
"starcoder2",
"gemma",
"baichuan",
]:
self.model_inputs.update(
{
"reuse_cache": self.options.reuse_cache,
}
)
if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon", "starcoder2", "gemma"]:
if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon", "starcoder2", "gemma", "baichuan"]:
if self.model.config.model_type != "falcon":
self.model_inputs.update(
{
7 changes: 5 additions & 2 deletions optimum/habana/transformers/generation/utils.py
@@ -108,6 +108,7 @@
"qwen2_moe",
"whisper",
"idefics2",
"baichuan",
]


@@ -1057,8 +1058,9 @@ def generate(
"starcoder2",
"qwen2_moe",
"gemma",
"baichuan",
]
), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma and starcoder2 at the moment"
), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, starcoder2 and baichuan at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
@@ -1253,8 +1255,9 @@ def generate(
"starcoder2",
"gemma",
"qwen2_moe",
"baichuan",
]:
if self.config.max_position_embeddings < calculated_max_length:
if hasattr(self.config, "max_position_embeddings") and self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)

# 8. determine generation mode
7 changes: 7 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -126,6 +126,9 @@
LlamaConfig,
MistralConfig,
MixtralConfig,
BaichuanConfig,
BaichuanTokenizer,
BaichuanForCausalLM,
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
gaudi_albert_forward,
@@ -626,3 +629,7 @@ def adapt_transformers_to_gaudi():
transformers.models.cohere.modeling_cohere.CohereForCausalLM = GaudiCohereForCausalLM
transformers.models.cohere.modeling_cohere.CohereModel.forward = gaudi_cohere_model_forward
transformers.models.cohere.modeling_cohere.CohereAttention.forward = gaudi_cohere_attention_forward

transformers.AutoConfig.register("baichuan", BaichuanConfig)
transformers.AutoTokenizer.register(BaichuanConfig, slow_tokenizer_class=BaichuanTokenizer)
transformers.AutoModelForCausalLM.register(BaichuanConfig, BaichuanForCausalLM)
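
These three registrations make the Baichuan classes resolvable through the standard Auto factories once adapt_transformers_to_gaudi() has run. A minimal, illustrative loading sketch (the checkpoint name and dtype are examples, not part of this commit; depending on the transformers version the hub checkpoint may still require trust_remote_code):

import torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

adapt_transformers_to_gaudi()  # registers BaichuanConfig / BaichuanTokenizer / BaichuanForCausalLM

model_id = "baichuan-inc/Baichuan2-7B-Chat"  # illustrative checkpoint
config = AutoConfig.from_pretrained(model_id)                        # resolves to BaichuanConfig
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)  # only a slow tokenizer is registered
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)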
6 changes: 6 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -262,3 +262,9 @@
GaudiWhisperModel,
GaudiWhisperSdpaAttention,
)

from .baichuan import (
BaichuanConfig,
BaichuanTokenizer,
BaichuanForCausalLM,
)
5 changes: 5 additions & 0 deletions optimum/habana/transformers/models/baichuan/__init__.py
@@ -0,0 +1,5 @@
from .configuration_baichuan import BaichuanConfig
from .tokenization_baichuan import BaichuanTokenizer
from .modeling_baichuan import (
    BaichuanForCausalLM,
)
77 changes: 77 additions & 0 deletions optimum/habana/transformers/models/baichuan/configuration_baichuan.py
@@ -0,0 +1,77 @@
# Copyright 2023 Baichuan Inc. All Rights Reserved.

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from the following sources:
https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/configuration_baichuan.py
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/configuration_baichuan.py
"""

import sys
from transformers.configuration_utils import PretrainedConfig

class BaichuanConfig(PretrainedConfig):
    model_type = "baichuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=125696,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=sys.maxsize,
        model_max_length=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        gradient_checkpointing=False,
        z_loss_weight=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # 13B config doesn't have max_position_embeddings
        if max_position_embeddings < sys.maxsize:
            self.max_position_embeddings = max_position_embeddings
        self.model_max_length = model_max_length
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.z_loss_weight = z_loss_weight
        self.gradient_checkpointing = gradient_checkpointing
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
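
Note the conditional handling of max_position_embeddings above: a 13B-style config created without that argument never gets the attribute at all, which is why the generate() change in generation/utils.py wraps the sequence-length comparison in hasattr. A small illustrative sketch of that behaviour (values are made up):

from optimum.habana.transformers.models.baichuan import BaichuanConfig

cfg_7b = BaichuanConfig(max_position_embeddings=4096)  # 7B-style config
cfg_13b = BaichuanConfig()                              # 13B-style config omits the field

assert hasattr(cfg_7b, "max_position_embeddings")
assert not hasattr(cfg_13b, "max_position_embeddings")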
109 changes: 109 additions & 0 deletions optimum/habana/transformers/models/baichuan/generation_utils.py
@@ -0,0 +1,109 @@
# Copyright 2023 Baichuan Inc. All Rights Reserved.

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from the following sources:
https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/generation_utils.py
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_utils.py
"""

from typing import List
from queue import Queue

import torch


def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int = 0):
    def _parse_messages(messages, split_role="user"):
        system, rounds = "", []
        round = []
        for i, message in enumerate(messages):
            if message["role"] == "system":
                assert i == 0
                system = message["content"]
                continue
            if message["role"] == split_role and round:
                rounds.append(round)
                round = []
            round.append(message)
        if round:
            rounds.append(round)
        return system, rounds

    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    max_input_tokens = model.config.model_max_length - max_new_tokens
    system, rounds = _parse_messages(messages, split_role="user")
    system_tokens = tokenizer.encode(system)
    max_history_tokens = max_input_tokens - len(system_tokens)

    history_tokens = []
    for round in rounds[::-1]:
        round_tokens = []
        for message in round:
            if message["role"] == "user":
                round_tokens.append(model.generation_config.user_token_id)
            else:
                round_tokens.append(model.generation_config.assistant_token_id)
            round_tokens.extend(tokenizer.encode(message["content"]))
        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = system_tokens + history_tokens
    if messages[-1]["role"] != "assistant":
        input_tokens.append(model.generation_config.assistant_token_id)
    input_tokens = input_tokens[-max_input_tokens:]  # truncate left
    return torch.LongTensor([input_tokens]).to(model.device)


class TextIterStreamer:
    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
        else:
            if len(value.shape) > 1:
                value = value[0]
            self.tokens.extend(value.tolist())
            self.text_queue.put(
                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get()
        if value is None:
            raise StopIteration()
        else:
            return value
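
Taken together, build_chat_input and TextIterStreamer cover prompt construction and incremental decoding for the chat checkpoints. Below is a minimal, illustrative sketch of how they would typically be wired up; it assumes a model and tokenizer loaded as in the earlier registration example, a checkpoint whose generation_config defines user_token_id and assistant_token_id (the Baichuan2 chat checkpoints do), and that generate() accepts a streamer argument as in upstream transformers. Each item yielded by the streamer is the full decoded text so far, not a delta.

from threading import Thread

from optimum.habana.transformers.models.baichuan.generation_utils import TextIterStreamer, build_chat_input

# Illustrative conversation; `model` and `tokenizer` come from the earlier loading sketch.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Which is the highest mountain in the world?"},
]
input_ids = build_chat_input(model, tokenizer, messages, max_new_tokens=512)

streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    kwargs={"inputs": input_ids, "max_new_tokens": 512, "streamer": streamer},
)
thread.start()
for text_so_far in streamer:
    print(text_so_far)  # cumulative decoded reply, refreshed as new tokens arrive
thread.join()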