Skip to content

Commit

Permalink
Merge branch 'master' of github.com:FlagAI-Open/FlagAI
Browse files Browse the repository at this point in the history
  • Loading branch information
ldwang committed Sep 27, 2023
2 parents 40a72ad + 2632237 commit bf17821
Show file tree
Hide file tree
Showing 4 changed files with 617 additions and 108 deletions.
199 changes: 144 additions & 55 deletions flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import importlib
import os
import copy
from flagai.model.file_utils import _get_model_id

from flagai.model.file_utils import _get_model_id, _get_checkpoint_path, _get_vocab_path, _get_model_files
from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
import torch

class LazyImport(object):

Expand All @@ -16,7 +17,7 @@ def __init__(self, name):
def __getattr__(self, name):
mod = self.cache.get(self.mod_name)
if not mod:
mod = importlib.import_module(self.mod_name)
mod = importlib.import_module(self.mod_name)
self.cache[self.mod_name] = mod
return getattr(mod, name)

Expand Down Expand Up @@ -163,7 +164,12 @@ def __init__(self,
model_name: str = "RoBERTa-base-ch",
model_dir: str = "./checkpoints/",
only_download_config: bool = False,
device="cpu",
device="cuda",
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
lora_dir=None,
qlora_dir=None,
quantization_config=None,
**kwargs):
"""
Args:
Expand Down Expand Up @@ -194,66 +200,149 @@ def __init__(self,
raw_model_name = copy.deepcopy(model_name)
model_name = model_name.lower()

if model_name not in MODEL_DICT:
if model_name not in MODEL_DICT and task_name != "aquila2":
print(f"The model_name: {model_name} is not be supported")
print(f"All supported models are {list(MODEL_DICT.keys())}")
return
if task_name == "aquila2":
download_path = os.path.join(model_dir, model_name)

if not os.path.exists(download_path):
# Try to download from ModelHub
try:
model_id = _get_model_id(model_name)
except:
raise FileNotFoundError("Model name not found in local path and ModelHub")
if model_id and model_id != "null":
model_files = eval(_get_model_files(model_name))
print("model files:" + str(model_files))
for file_name in model_files:
if not file_name.endswith("bin"):
_get_vocab_path(download_path, file_name, model_id)

brief_model_name = MODEL_DICT[model_name][2]
model_type = MODEL_DICT[model_name][3]
# The dir to save config, vocab and model.
if os.path.exists(
os.path.join(download_path, 'config.json')):
if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
model_parallel_size = int(os.getenv("MODEL_PARALLEL_SIZE"))
if model_parallel_size > 1:
# if gpus == nums_of_modelhub_models
# can load
# else need to download the pytorch_model.bin and to recut.
model_hub_parallel_size = 0
for f in model_files:
if "pytorch_model_" in f:
model_hub_parallel_size += 1
else:
model_parallel_size = 1

self.model_name = ALL_TASK.get(f"{brief_model_name}_{task_name}", None)
if self.model_name is None:
print(f"For the model_name: {model_name}, task_name: {task_name} \
is not be supported.")
tasks = self.get_task_name(brief_model_name)
print(
f"For the model_name: {model_name}, these tasks are be supported: {tasks}"
)
return
download_path = os.path.join(model_dir, raw_model_name)
print("*" * 20, task_name, model_name)
model_name_ = self.is_exist_finetuned_model(raw_model_name, task_name)
self.model = getattr(LazyImport(self.model_name[0]),
self.model_name[1]).from_pretrain(
download_path=model_dir,
model_name=model_name_,
only_download_config=only_download_config,
device=device,
**kwargs)
if "pytorch_model_01.bin" in model_files and model_parallel_size > 1 and model_hub_parallel_size == model_parallel_size:
# Only to download the model slices(megatron-lm).
for file_to_load in model_files:
if "pytorch_model_" in file_to_load:
_get_checkpoint_path(download_path, file_to_load,
model_id)

if model_type == "nlp":
if brief_model_name in ["galactica",]:
self.tokenizer = getattr(LazyImport(MODEL_DICT[model_name][4]),
MODEL_DICT[model_name][5])(download_path)
else :
tokenizer_class = getattr(LazyImport("flagai.data.tokenizer"),
"Tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(
model_name, cache_dir=download_path)
elif 'pytorch_model.bin' in model_files:
checkpoint_path = _get_checkpoint_path(
download_path, 'pytorch_model.bin', model_id)
else:
checkpoint_merge = {}
# maybe multi weights files
for file_to_load in model_files:
if "pytorch_model-0" in file_to_load:
_get_checkpoint_path(download_path, file_to_load,
model_id)

if qlora_dir:
from transformers import BitsAndBytesConfig
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch_dtype,
)

elif model_type == "mm":
if model_name.startswith("altdiffusion"):
self.process = getattr(LazyImport(MODEL_DICT[model_name][4]),
MODEL_DICT[model_name][5]).from_pretrained(os.path.join(model_dir, raw_model_name))
self.tokenizer = self.process.tokenizer
self.model.tokenizer = self.tokenizer
elif "altclip" not in model_name:
from flagai.data.tokenizer.clip.tokenizer import ClipTokenizer
self.tokenizer = ClipTokenizer(bpe_path=os.path.join(download_path, 'bpe_simple_vocab_16e6.txt.gz'))
self.transform = None
else:

self.process = getattr(LazyImport(MODEL_DICT[model_name][4]),
MODEL_DICT[model_name][5]).from_pretrained(
os.path.join(model_dir, raw_model_name))
self.transform = self.process.feature_extractor
self.tokenizer = self.process.tokenizer

model = AquilaForCausalLM.from_pretrained(download_path,
low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,
quantization_config=quantization_config)

model.eval()
# from accelerate import load_checkpoint_and_dispatch
# model = load_checkpoint_and_dispatch(
# model, model_dir+model_name, device_map="balanced", no_split_module_classes=["LlamaDecoderLayer"])
if not qlora_dir:
model.to(device)
if lora_dir:
from flagai.model.tools.peft import PeftModel
model = PeftModel.from_pretrained(model, lora_dir)
print("lora modules loaded")
if qlora_dir:
from flagai.model.tools.peft import PeftModel
model = PeftModel.from_pretrained(model, qlora_dir)
print("Qlora modules loaded")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir+model_name)
self.model = model
self.tokenizer = tokenizer
else:
self.tokenizer = None
self.transform = None
brief_model_name = MODEL_DICT[model_name][2]
model_type = MODEL_DICT[model_name][3]
# The dir to save config, vocab and model.

self.model_name = ALL_TASK.get(f"{brief_model_name}_{task_name}", None)
if self.model_name is None:
print(f"For the model_name: {model_name}, task_name: {task_name} \
is not be supported.")
tasks = self.get_task_name(brief_model_name)
print(
f"For the model_name: {model_name}, these tasks are be supported: {tasks}"
)
return
download_path = os.path.join(model_dir, raw_model_name)
print("*" * 20, task_name, model_name)
model_name_ = self.is_exist_finetuned_model(raw_model_name, task_name)
self.model = getattr(LazyImport(self.model_name[0]),
self.model_name[1]).from_pretrain(
download_path=model_dir,
model_name=model_name_,
only_download_config=only_download_config,
device=device,
**kwargs)

if model_type == "nlp":
if brief_model_name in ["galactica",]:
self.tokenizer = getattr(LazyImport(MODEL_DICT[model_name][4]),
MODEL_DICT[model_name][5])(download_path)
# elif 'Aquila2-7b' in model_name:

else :
tokenizer_class = getattr(LazyImport("flagai.data.tokenizer"),
"Tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(
model_name, cache_dir=download_path)

elif model_type == "mm":
if model_name.startswith("altdiffusion"):
self.process = getattr(LazyImport(MODEL_DICT[model_name][4]),
MODEL_DICT[model_name][5]).from_pretrained(os.path.join(model_dir, raw_model_name))
self.tokenizer = self.process.tokenizer
self.model.tokenizer = self.tokenizer
elif "altclip" not in model_name:
from flagai.data.tokenizer.clip.tokenizer import ClipTokenizer
self.tokenizer = ClipTokenizer(bpe_path=os.path.join(download_path, 'bpe_simple_vocab_16e6.txt.gz'))
self.transform = None
else:

self.process = getattr(LazyImport(MODEL_DICT[model_name][4]),
MODEL_DICT[model_name][5]).from_pretrained(
os.path.join(model_dir, raw_model_name))
self.transform = self.process.feature_extractor
self.tokenizer = self.process.tokenizer

else:
self.tokenizer = None
self.transform = None

def is_exist_finetuned_model(self, raw_model_name, task_name):
try:
Expand Down
128 changes: 128 additions & 0 deletions flagai/model/aquila2/configuration_aquila.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Aquila model configuration"""

from transformers import PretrainedConfig



class AquilaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`AquilaModel`]. It is used to instantiate an Aquila
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Aquila-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Aquila model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`AquilaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
Example:
```python
>>> from transformers import AquilaModel, AquilaConfig
>>> # Initializing a Aquila aquila-7b style configuration
>>> configuration = AquilaConfig()
>>> # Initializing a model from the aquila-7b style configuration
>>> model = AquilaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "aquila"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=100008,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads

self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
Loading

0 comments on commit bf17821

Please sign in to comment.