-
Notifications
You must be signed in to change notification settings - Fork 200
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Below is an example for baichuan-inc/Baichuan2-7B-Chat: python3 run_generation.py \ --model_name_or_path baichuan-inc/Baichuan2-7B-Chat \ --bf16 --trim_logits --batch_size 1 \ --max_input_tokens 1024 --max_new_tokens 512 \ --use_kv_cache --use_hpu_graphs --use_flash_attention \ --reuse_cache \ --no-ignore_eos Below is an example for baichuan-inc/Baichuan2-13B-Chat: python3 run_generation.py \ --model_name_or_path baichuan-inc/Baichuan2-13B-Chat \ --bf16 --trim_logits --batch_size 1 \ --max_input_tokens 1024 --max_new_tokens 512 \ --use_kv_cache --use_hpu_graphs --bucket_size 256 \ --bucket_internal --reuse_cache \ --no-ignore_eos Co-authored-by: Jianqian Zhou <[email protected]> Co-authored-by: Wei Lin <[email protected]> Signed-off-by: Haihao Xiang <[email protected]>
- Loading branch information
1 parent
7ac1db1
commit 952eb48
Showing
14 changed files
with
2,167 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .configuration_baichuan import BaichuanConfig | ||
from .tokenization_baichuan import BaichuanTokenizer | ||
from .modeling_baichuan import ( | ||
BaichuanForCausalLM, | ||
) |
77 changes: 77 additions & 0 deletions
77
optimum/habana/transformers/models/baichuan/configuration_baichuan.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright 2023 Baichuan Inc. All Rights Reserved. | ||
|
||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. | ||
# | ||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | ||
# and OPT implementations in this library. It has been modified from its | ||
# original forms to accommodate minor architectural differences compared | ||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
Adapted from the following sources: | ||
https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/configuration_baichuan.py | ||
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/configuration_baichuan.py | ||
""" | ||
|
||
import sys | ||
from transformers.configuration_utils import PretrainedConfig | ||
|
||
class BaichuanConfig(PretrainedConfig):
    """Configuration holder for Baichuan causal language models.

    Every constructor argument is stored verbatim as an attribute of the same
    name, except for the token ids and ``tie_word_embeddings``, which are
    forwarded (together with any extra ``kwargs``) to
    ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Stored as ``self.vocab_size``.
        hidden_size: Stored as ``self.hidden_size``.
        intermediate_size: Stored as ``self.intermediate_size``.
        num_hidden_layers: Stored as ``self.num_hidden_layers``.
        num_attention_heads: Stored as ``self.num_attention_heads``.
        hidden_act: Stored as ``self.hidden_act``.
        max_position_embeddings: Only stored when it is smaller than
            ``sys.maxsize``; the default ``sys.maxsize`` acts as a
            "not provided" sentinel, in which case the attribute is not set
            by this constructor at all.
        model_max_length: Stored as ``self.model_max_length``.
        initializer_range: Stored as ``self.initializer_range``.
        rms_norm_eps: Stored as ``self.rms_norm_eps``.
        use_cache: Stored as ``self.use_cache``.
        pad_token_id: Forwarded to ``PretrainedConfig``.
        bos_token_id: Forwarded to ``PretrainedConfig``.
        eos_token_id: Forwarded to ``PretrainedConfig``.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``.
        gradient_checkpointing: Stored as ``self.gradient_checkpointing``.
        z_loss_weight: Stored as ``self.z_loss_weight``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "baichuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=125696,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=sys.maxsize,
        model_max_length=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        gradient_checkpointing=False,
        z_loss_weight=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # 13B config doesn't have max_position_embeddings, so the attribute is
        # only created when the caller supplied a real (non-sentinel) value.
        if max_position_embeddings < sys.maxsize:
            self.max_position_embeddings = max_position_embeddings
        self.model_max_length = model_max_length
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.z_loss_weight = z_loss_weight
        # Store the flag itself. Wrapping it in a one-element tuple (as a stray
        # trailing comma does) makes it truthy even when the caller passed
        # False, which misleads any `if config.gradient_checkpointing:` check.
        self.gradient_checkpointing = gradient_checkpointing
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
109 changes: 109 additions & 0 deletions
109
optimum/habana/transformers/models/baichuan/generation_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright 2023 Baichuan Inc. All Rights Reserved. | ||
|
||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. | ||
# | ||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | ||
# and OPT implementations in this library. It has been modified from its | ||
# original forms to accommodate minor architectural differences compared | ||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
Adapted from the following sources: | ||
https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/generation_utils.py | ||
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_utils.py | ||
""" | ||
|
||
from typing import List | ||
from queue import Queue | ||
|
||
import torch | ||
|
||
|
||
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int = 0):
    """Encode a chat history into a single batch of input token ids.

    The conversation is split into rounds, each starting at a ``"user"``
    message. Rounds are added newest-first until the token budget
    (``model.config.model_max_length - max_new_tokens`` minus the encoded
    system prompt) is used up; the newest round is always kept, even when it
    alone exceeds the budget. The result is finally truncated from the left
    to the input budget.

    Args:
        model: Object providing ``generation_config`` (``max_new_tokens``,
            ``user_token_id``, ``assistant_token_id``), ``config.model_max_length``
            and ``device``.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        messages: Dicts with ``"role"`` and ``"content"`` keys. A ``"system"``
            message is only allowed at index 0 (enforced with ``assert``).
        max_new_tokens: Generation budget; a falsy value falls back to
            ``model.generation_config.max_new_tokens``.

    Returns:
        A ``torch.LongTensor`` holding one row of token ids, moved to
        ``model.device``.
    """

    def _split_rounds(history, split_role="user"):
        # Group messages into rounds; a new round opens whenever `split_role`
        # speaks and the current round already has content.
        system_prompt = ""
        grouped = []
        current = []
        for index, entry in enumerate(history):
            role = entry["role"]
            if role == "system":
                assert index == 0
                system_prompt = entry["content"]
                continue
            if role == split_role and current:
                grouped.append(current)
                current = []
            current.append(entry)
        if current:
            grouped.append(current)
        return system_prompt, grouped

    if not max_new_tokens:
        max_new_tokens = model.generation_config.max_new_tokens
    max_input_tokens = model.config.model_max_length - max_new_tokens

    system, rounds = _split_rounds(messages, split_role="user")
    system_tokens = tokenizer.encode(system)
    max_history_tokens = max_input_tokens - len(system_tokens)

    history_tokens = []
    for chat_round in reversed(rounds):
        round_tokens = []
        for message in chat_round:
            # Each message is prefixed with the special id of its speaker.
            if message["role"] == "user":
                speaker_id = model.generation_config.user_token_id
            else:
                speaker_id = model.generation_config.assistant_token_id
            round_tokens.append(speaker_id)
            round_tokens.extend(tokenizer.encode(message["content"]))

        # The newest round is always accepted; older ones only if they fit.
        over_budget = len(history_tokens) + len(round_tokens) > max_history_tokens
        if history_tokens and over_budget:
            break
        history_tokens = round_tokens + history_tokens  # concat left
        if len(history_tokens) >= max_history_tokens:
            break

    input_tokens = system_tokens + history_tokens
    if messages[-1]["role"] != "assistant":
        # Cue the model to answer as the assistant.
        input_tokens.append(model.generation_config.assistant_token_id)
    input_tokens = input_tokens[-max_input_tokens:]  # truncate left
    return torch.LongTensor([input_tokens]).to(model.device)
|
||
|
||
class TextIterStreamer:
    """Iterator that yields the cumulative decoded text as tokens arrive.

    A producer calls ``put`` with new token ids and ``end`` when finished;
    a consumer iterates over the instance. Each yielded item is the decoding
    of *all* tokens received so far, not just the latest increment.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        """Store the tokenizer and options and create the empty text queue."""
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        # The first `put` call is treated as the prompt.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Accept new token ids and enqueue the text decoded so far.

        When ``skip_prompt`` is set, the first call is discarded. Inputs with
        more than one dimension are reduced to their first row.
        """
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        token_ids = value[0] if len(value.shape) > 1 else value
        self.tokens.extend(token_ids.tolist())
        decoded = self.tokenizer.decode(
            self.tokens, skip_special_tokens=self.skip_special_tokens
        )
        self.text_queue.put(decoded)

    def end(self):
        """Signal the end of the stream by enqueueing a ``None`` sentinel."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        """Block until the next text is available; stop on the sentinel."""
        item = self.text_queue.get()
        if item is None:
            raise StopIteration()
        return item
Oops, something went wrong.