diff --git a/flagai/auto_model/auto_loader.py b/flagai/auto_model/auto_loader.py index 68521601..726dc340 100755 --- a/flagai/auto_model/auto_loader.py +++ b/flagai/auto_model/auto_loader.py @@ -210,7 +210,10 @@ def __init__(self, if task_name == "aquila2": from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM download_path = os.path.join(model_dir, model_name) - + + if not torch_dtype and '34b' in model_name.lower(): + torch_dtype = torch.bfloat16 + if not os.path.exists(download_path): # Try to download from ModelHub try: @@ -255,9 +258,10 @@ def __init__(self, for file_to_load in model_files: if "pytorch_model-0" in file_to_load: _get_checkpoint_path(download_path, file_to_load, - model_id) - - if qlora_dir: + model_id) + if 'quantization_config' in kwargs: + quantization_config = kwargs['quantization_config'] + elif qlora_dir: from transformers import BitsAndBytesConfig quantization_config=BitsAndBytesConfig( load_in_4bit=True, @@ -265,14 +269,14 @@ def __init__(self, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch_dtype, ) + else: + quantization_config = None if inference_mode: - if qlora_dir: - model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, - quantization_config=quantization_config) - else: - model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype,) + + model = AquilaForCausalLM.from_pretrained(download_path,low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, + quantization_config=quantization_config) model.eval() - if not qlora_dir: + if not quantization_config: model.to(device) if lora_dir: from flagai.model.tools.peft import PeftModel diff --git a/setup.py b/setup.py index 3a929b7b..da686a57 100755 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="flagai", - version="v1.8.0", + version="v1.8.1", description="FlagAI aims to help researchers and developers to freely train and test large-scale models for NLP/CV/VL tasks.", long_description=open("README.md", encoding="utf-8").read(), long_description_content_type="text/markdown", @@ -19,19 +19,19 @@ install_requires=[ 'nltk>=3.6.7', 'sentencepiece>=0.1.96', - 'boto3==1.17.32', + 'boto3>=1.17.32', 'pandas>=1.3.5', 'jieba>=0.42.1', 'scikit-learn>=1.0.2', 'tensorboard>=2.9.0', 'transformers>=4.31.0', 'datasets>=2.0.0', - 'setuptools==66.0.0', + 'setuptools>=66.0.0', 'protobuf==3.19.6', 'ftfy', 'Pillow>=9.3.0', 'einops>=0.3.0', - 'diffusers==0.7.2', + 'diffusers>=0.7.2', 'pytorch-lightning>=1.6.5', 'taming-transformers-rom1504==0.0.6', 'rouge-score',