
Commit

Merge remote-tracking branch 'refs/remotes/origin/feat/refactor3' into feat/refactor3
Jintao-Huang committed Oct 12, 2024
2 parents 6000ebf + f634d3c commit 53b5e5f
Showing 22 changed files with 188 additions and 210 deletions.
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -14,7 +14,7 @@
from .argument import (EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments, RLHFArguments,
WebuiArguments, AppUIArguments)
from .template import TEMPLATE_MAPPING, Template, StopWords, get_template, TemplateType, register_template
from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, get_default_template_type, HfConfigFactory
from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory
from .dataset import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
TextGenerationPreprocessor, DatasetName, DatasetLoader, HubDatasetLoader, LocalDatasetLoader,
@@ -42,7 +42,7 @@
'RLHFArguments', 'AppUIArguments'
],
'template': ['TEMPLATE_MAPPING', 'Template', 'StopWords', 'get_template', 'TemplateType', 'register_template'],
'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'get_default_template_type', 'HfConfigFactory'],
'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory'],
'dataset': [
'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'ConversationsPreprocessor',
'ListPreprocessor', 'PreprocessFunc', 'RenameColumnsPreprocessor', 'SmartPreprocessor',
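
The user-visible API change in this file is that safe_snapshot_download replaces get_default_template_type in the package exports. A minimal usage sketch, assuming the call signature that prepare_model_info uses later in this commit (the model id below is a hypothetical placeholder):

from swift.llm import safe_snapshot_download

# Resolve a local snapshot directory without loading model weights.
model_dir = safe_snapshot_download('Qwen/Qwen2-7B-Instruct', revision=None, load_model=False)
print(model_dir)  # e.g. a local cache directory containing config.json and tokenizer files
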
5 changes: 3 additions & 2 deletions swift/llm/argument/data_args.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union

from swift.llm import DATASET_MAPPING, TEMPLATE_MAPPING, get_default_template_type, register_dataset_info_file
from swift.llm import DATASET_MAPPING, TEMPLATE_MAPPING, register_dataset_info_file
from swift.utils import get_logger

logger = get_logger()
@@ -70,7 +70,8 @@ def select_template(self: Union['SftArguments', 'InferArguments']):
# [TODO:]
"""If setting template to `auto`, find a proper one"""
if self.template_type == 'auto':
self.template_type = get_default_template_type(self.model_type)
from swift.llm.model.register import ModelInfoReader
self.template_type = ModelInfoReader.get_default_template_type(self.model_type)
logger.info(f'Setting template_type: {self.template_type}')

def __post_init__(self):
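
For context, a hedged sketch of what ModelInfoReader.get_default_template_type is presumed to do here: look up the default template registered for a model_type. The 'template' key and the lookup itself are assumptions; the real implementation lives in swift/llm/model/register.py and may differ.

from typing import Optional

from swift.llm import MODEL_MAPPING


class ModelInfoReader:
    @staticmethod
    def get_default_template_type(model_type: str) -> Optional[str]:
        # Assumed behaviour only: return the template registered for this model_type.
        return MODEL_MAPPING.get(model_type, {}).get('template')
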
100 changes: 24 additions & 76 deletions swift/llm/argument/model_args.py
@@ -4,16 +4,17 @@
from typing import List, Literal, Optional, Union

import torch
from transformers import AutoConfig
from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available
from transformers.utils.versions import require_version

from swift.llm import MODEL_KEYS_MAPPING, MODEL_MAPPING, ModelType, RLHFArguments
from swift.llm.model import fix_do_sample_warning
from swift.llm import MODEL_KEYS_MAPPING, MODEL_MAPPING
from swift.llm.model import fix_do_sample_warning, get_default_torch_dtype
from swift.utils import get_dist_setting, get_logger, use_hf_hub

logger = get_logger()

dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32', None: 'auto'}
dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}
dtype_mapping_reversed = {v: k for k, v in dtype_mapping.items()}
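
Note that the None: 'auto' entry has been dropped, so dtype_mapping_reversed now covers only concrete dtypes and 'auto' must be resolved before any reverse lookup. A small self-contained illustration:

import torch

dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}
dtype_mapping_reversed = {v: k for k, v in dtype_mapping.items()}

assert dtype_mapping[torch.bfloat16] == 'bf16'
assert dtype_mapping_reversed['fp16'] is torch.float16
# 'auto' is intentionally absent: select_dtype() resolves it via
# get_default_torch_dtype() before indexing dtype_mapping.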


@@ -67,7 +68,6 @@ def select_bnb(self) -> None:
self.bnb_4bit_comp_dtype = self.dtype

bnb_4bit_compute_dtype = dtype_mapping_reversed[self.bnb_4bit_comp_dtype]
assert bnb_4bit_compute_dtype in {torch.float16, torch.bfloat16, torch.float32}

load_in_4bit, load_in_8bit = False, False # default value
if self.quant_method == 'bnb':
@@ -90,10 +90,10 @@ def __post_init__(self: Union['SftArguments', 'InferArguments']):
@dataclass
class ModelArguments:

# You can specify the model by either using the model_type or model_id_or_path.
# You can specify the model by model.
model: Optional[str] = None
model_type: Optional[str] = field(
default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'})
model_id_or_path: Optional[str] = None
model_revision: Optional[str] = None

dtype: Literal['bf16', 'fp16', 'fp32', 'auto'] = 'auto'
@@ -111,7 +111,7 @@ class ModelArguments:

def prepare_model_extra_args(self: 'SftArguments'):
"""Prepare model kwargs and set them to the env"""
self.parse_to_dict(self, 'model_kwargs')
self.parse_to_dict('model_kwargs')
for k, v in self.model_kwargs.items():
k = k.upper()
os.environ[k] = str(v)
@@ -126,22 +126,12 @@ def prepare_device_map_args(self: 'SftArguments'):
if isinstance(v, int):
self.device_map_config[k] += local_rank

def get_model_group(self):
"""Find the model group. This is used to find the model structure"""
model_type = (self.model_type or self.model_id_or_path).replace('-', '_')
model_group = None
for key in MODEL_KEYS_MAPPING.keys():
if key in model_type.lower():
model_group = key
break
return model_group

def check_flash_attn(self) -> None:
"""Some models do not support flash-attention"""
model_info = MODEL_MAPPING.get(self.model_type, {})
support_flash_attn = model_info.get('support_flash_attn', False)
if self.use_flash_attn and not support_flash_attn:
logger.warning(f'use_flash_attn: {self.use_flash_attn}, ' f'but support_flash_attn: {support_flash_attn}')
if 'flash' in self.attn_impl and not support_flash_attn:
logger.warning(f'attn_impl: {self.attn_impl}, ' f'but support_flash_attn: {support_flash_attn}')

@property
def is_multimodal(self) -> bool:
@@ -153,7 +143,7 @@ def is_multimodal(self) -> bool:
return 'multi-modal' in tags

def select_dtype(self) -> None:
"""If dtype is `auto`, find a proper dtype by the sft_type/GPU"""
"""If dtype is `auto`, find a proper dtype by the train_type/GPU"""
# Compatible with --fp16/--bf16
from .train_args import SftArguments
if isinstance(self, SftArguments):
@@ -165,26 +155,13 @@ def select_dtype(self) -> None:
break
# handle dtype == 'auto'
if self.dtype == 'auto':
if is_torch_cuda_available() or is_torch_npu_available():
if is_torch_bf16_gpu_available():
model_torch_dtype = MODEL_MAPPING[self.model_type].get('torch_dtype')
if model_torch_dtype is not None:
self.dtype = dtype_mapping[model_torch_dtype]
elif isinstance(self, SftArguments):
self.dtype = 'bf16'
# else: Keep 'auto'. According to the model's config.json file,
# this behavior is executed in the get_model_tokenizer function.
# This situation will only occur during inference.
else:
self.dtype = 'fp16'
else:
# cpu
self.dtype = 'fp32'
torch_dtype = get_default_torch_dtype(self.model_info.torch_dtype)
self.dtype = dtype_mapping[torch_dtype]
logger.info(f'Setting args.dtype: {self.dtype}')
# Check the validity of dtype
if is_torch_cuda_available() or is_torch_npu_available():
if self.dtype == 'fp16':
if isinstance(self, SftArguments) and self.sft_type == 'full':
if isinstance(self, SftArguments) and self.train_type == 'full':
self.dtype = 'fp32'
logger.warning(
'Fine-tuning with full parameters does not support fp16, and is prone to NaN. '
@@ -208,51 +185,22 @@ def select_dtype(self) -> None:
else:
raise ValueError(f'args.dtype: {self.dtype}')

def select_model_type(self) -> None:
"""model_type may be None, find the right one by `model_id_or_path`"""
from swift.llm.argument import InferArguments
if self.model_id_or_path is not None:
model_mapping_reversed = {}
for k, v in MODEL_MAPPING.items():
if use_hf_hub():
model_id = v.get('hf_model_id')
else:
model_id = v.get('model_id_or_path')
if model_id is None:
continue
model_id = model_id.lower()
model_mapping_reversed[model_id] = k
model_id_or_path = self.model_id_or_path
model_id_or_path_lower = model_id_or_path.lower()

if self.model_type is None and model_id_or_path_lower in model_mapping_reversed:
model_type = model_mapping_reversed[model_id_or_path_lower]
assert self.model_type is None or self.model_type == model_type
self.model_type = model_type
logger.info(f'Setting args.model_type: {model_type}')
else:
if (isinstance(self, InferArguments) and 'checkpoint-' in model_id_or_path
and 'merged' not in model_id_or_path and self.ckpt_dir is None):
raise ValueError('Please use `--ckpt_dir vx-xxx/checkpoint-xxx` to use the checkpoint.')

model_info = MODEL_MAPPING.get(self.model_type, {})
if self.model_revision is not None:
model_info['revision'] = self.model_revision
logger.info(f"Setting model_info['revision']: {self.model_revision}")
elif use_hf_hub():
model_info['revision'] = 'main'
self.model_revision = model_info['revision']
if self.model_id_or_path is None:
self.model_id_or_path = model_info['hf_model_id'] if use_hf_hub() else model_info['model_id_or_path']
requires = model_info.get('requires', [])
for require in requires:
require_version(require)
def prepare_model_info(self) -> None:
from swift.llm import safe_snapshot_download, HfConfigFactory
model_dir = safe_snapshot_download(self.model, revision=self.model_revision, load_model=False)
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
model_info = HfConfigFactory.get_model_info(model_config, model_dir)
self.model_info = model_info
if self.model_type is None:
self.model_type = model_info.model_type
if self.quant_method is None:
self.quant_method = model_info.quant_method

def __post_init__(self: Union['SftArguments', 'InferArguments']):
if self.rope_scaling:
logger.info(f'rope_scaling is set to {self.rope_scaling}, please remember to set max_length')
self.prepare_model_extra_args()
self.prepare_device_map_args()
self.prepare_model_info()
self.check_flash_attn()
self.select_dtype()
self.select_model_type()
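
For readers following the refactor, a sketch of the fields the new model_info object is assumed to carry, inferred purely from the attributes accessed in this commit (model_type, torch_dtype, quant_method, bits). The actual object returned by HfConfigFactory.get_model_info may hold more.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ModelInfoSketch:
    # Assumed shape only; inferred from the model_info.* accesses in this diff.
    model_type: Optional[str] = None
    torch_dtype: Optional[torch.dtype] = None  # dtype recorded in config.json, if any
    quant_method: Optional[str] = None         # e.g. 'gptq' / 'awq' / 'bnb'
    bits: Optional[int] = None                 # quantization bit-width, if quantized
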
6 changes: 3 additions & 3 deletions swift/llm/argument/train_args.py
@@ -317,13 +317,13 @@ def prepare_train_type(self):
if self.is_adapter:
assert self.freeze_parameters_ratio == 0. and len(
self.additional_trainable_parameters) == 0, ('lora does not support, please use `--train_type full`')
if self.is_quant_model():
if self.quant_method is not None:
assert self.quantization_bit == 0, (
f'{self.model_type} is already a quantized model and does not need to be quantized again.')
elif self.train_type == 'full':
if self.freeze_vit:
if self.get_model_group():
vision_tower = MODEL_KEYS_MAPPING[self.get_model_group()].vision_tower
if self.model_type in MODEL_KEYS_MAPPING:
vision_tower = MODEL_KEYS_MAPPING[self.model_type].vision_tower
if vision_tower:
self.freeze_parameters += vision_tower
assert 0 <= self.freeze_parameters_ratio <= 1
4 changes: 2 additions & 2 deletions swift/llm/eval/eval.py
@@ -290,8 +290,8 @@ def eval_opencompass(args: EvalArguments) -> List[Dict[str, Any]]:
port = args.port
# health check: try to get model_type until raises
get_model_type(port, args.deploy_timeout)
model_type = 'default-lora' if args.sft_type in ('lora',
'longlora') and not args.merge_lora else args.model_type
model_type = 'default-lora' if args.train_type in ('lora',
'longlora') and not args.merge_lora else args.model_type
from swift.llm.infer.deploy import is_generation_template
if is_generation_template(args.template_type):
url = f'http://127.0.0.1:{port}/v1/completions'
4 changes: 2 additions & 2 deletions swift/llm/export/export.py
@@ -184,7 +184,7 @@ def llm_export(args: ExportArguments) -> None:
logger.info(f'args: {args}')
seed_everything(args.seed)
if args.to_peft_format:
assert args.sft_type == 'lora', f'args.sft_type: {args.sft_type}'
assert args.train_type == 'lora', f'args.train_type: {args.train_type}'
args.ckpt_dir = swift_to_peft_format(args.ckpt_dir)

if args.merge_lora:
@@ -248,7 +248,7 @@ def llm_export(args: ExportArguments) -> None:
assert args.quant_output_dir is not None
_args = args
assert args.quantization_bit == 0, f'args.quantization_bit: {args.quantization_bit}'
assert args.sft_type == 'full', 'you need to merge lora'
assert args.train_type == 'full', 'you need to merge lora'
if args.quant_method == 'awq':
from awq import AutoAWQForCausalLM
model, template = prepare_model_template(
6 changes: 3 additions & 3 deletions swift/llm/infer/infer.py
@@ -90,7 +90,7 @@ def merge_lora(args: InferArguments,
**kwargs) -> Optional[str]:
logger.info(f'replace_if_exists: {replace_if_exists}')
assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.'
assert args.sft_type in args.adapters_can_be_merged, 'Only supports lora & llamapro series models'
assert args.train_type in args.adapters_can_be_merged, 'Only supports lora & llamapro series models'
assert not args.is_quant_model(), f'{args.model_type} is a quantized model and does not support merge-lora.'
if args.quantization_bit != 0:
logger.warning('It is not recommended to merge quantized models, '
@@ -141,9 +141,9 @@ def merge_lora(args: InferArguments,
if tempdir:
shutil.rmtree(tempdir, ignore_errors=True)

logger.info("Setting args.sft_type: 'full'")
logger.info("Setting args.train_type: 'full'")
logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
args.sft_type = 'full'
args.train_type = 'full'
args.ckpt_dir = merged_lora_path
return merged_lora_path

4 changes: 2 additions & 2 deletions swift/llm/infer/lmdeploy.py
@@ -77,10 +77,10 @@ def __init__(self, args: InferArguments, use_async: bool = False, **kwargs):
logger.info(f'device_count: {torch.cuda.device_count()}')

assert args.quantization_bit == 0, 'not support bnb'
assert not args.sft_type == 'lora', 'you need to merge lora'
assert not args.train_type == 'lora', 'you need to merge lora'
# Loading Model and Tokenizer
model_id_or_path = None
if args.sft_type == 'full' and args.ckpt_dir is not None:
if args.train_type == 'full' and args.ckpt_dir is not None:
model_id_or_path = args.ckpt_dir
elif args.model_id_or_path is not None:
model_id_or_path = args.model_id_or_path
2 changes: 1 addition & 1 deletion swift/llm/infer/transformers.py
@@ -115,7 +115,7 @@ def prepare_model_template_hf(args: InferArguments, use_async: bool = False, **k
if args.use_flash_attn is not None:
kwargs['use_flash_attn'] = args.use_flash_attn
model_id_or_path = None
if args.sft_type == 'full' and args.ckpt_dir is not None:
if args.train_type == 'full' and args.ckpt_dir is not None:
model_id_or_path = args.ckpt_dir
elif args.model_id_or_path is not None:
model_id_or_path = args.model_id_or_path
46 changes: 23 additions & 23 deletions swift/llm/model/register.py
@@ -11,7 +11,7 @@
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PretrainedConfig,
PreTrainedModel, PreTrainedTokenizerBase)
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available
from transformers.utils.versions import require_version

from swift.llm import TemplateType
@@ -165,7 +165,7 @@ def get_model_tokenizer_from_local(model_dir: str,
logger.info(f'model_kwargs: {model_kwargs}')
model = automodel_class.from_pretrained(
model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs)
model.quant_method = kwargs.get('quant_method')
model.quant_method = kwargs.get('quant_method') # TODO: check bnb
model.quant_bits = kwargs.get('bits')
model.is_training = kwargs.get('is_training', False)
max_model_len = HfConfigFactory.get_max_model_len(model_config)
@@ -261,15 +261,20 @@ def get_default_device_map():
return 'auto'


def get_default_torch_dtype(model_config: PretrainedConfig, quant_info: Optional[Dict[str, Any]] = None) -> torch.dtype:
torch_dtype = HfConfigFactory.get_torch_dtype(model_config)
if torch_dtype is None:
if quant_info is None:
quant_info = HfConfigFactory.get_quant_info(model_config)
torch_dtype = quant_info.get('torch_dtype')
if torch_dtype in {torch.float32, None}:
torch_dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
return torch_dtype
def get_default_torch_dtype(torch_dtype: torch.dtype):
# torch_dtype: torch_dtype in config.json
if is_torch_cuda_available() or is_torch_npu_available():
if is_torch_bf16_gpu_available():
if torch_dtype in {torch.float16, torch.bfloat16}:
res = torch_dtype
else:
res = torch.bfloat16
else:
res = torch.float16
else:
# cpu
res = torch.float32
return res
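
In short, the new helper keeps an fp16/bf16 dtype from config.json when the hardware supports it, falls back to bf16 (or fp16 without bf16 support) on accelerators, and to fp32 on CPU. An illustrative call (the result depends on the runtime environment):

import torch

from swift.llm.model import get_default_torch_dtype

# bf16-capable GPU/NPU -> torch.bfloat16; older CUDA GPU -> torch.float16; CPU-only -> torch.float32
torch_dtype = get_default_torch_dtype(torch.float32)

# select_dtype() in model_args.py then maps the result back to its string form:
dtype_str = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}[torch_dtype]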


def get_model_tokenizer(model_id_or_path: str,
@@ -309,23 +314,18 @@ def get_model_tokenizer(model_id_or_path: str,
model_kwargs['device_map'] = get_default_device_map()
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
kwargs['model_config'] = model_config
model_info = HfConfigFactory.get_model_info(model_config, model_dir)
if model_type is None:
model_types = HfConfigFactory.get_matched_model_types(model_config, model_dir)
if len(model_types) > 1:
raise ValueError('Unable to obtain the accurate model_type based on the model architecture. '
f'Please explicitly provide the model_type. Available model_types: {model_types}')
model_type = model_types[0]
model_type = model_info.model_type
logger.info(f'Setting model_type: {model_type}')
quant_info = HfConfigFactory.get_quant_info(model_config)
if torch_dtype is None:
torch_dtype = get_default_torch_dtype(model_config, quant_info)
torch_dtype = get_default_torch_dtype(model_info.torch_dtype)
logger.info(f'Setting torch_dtype: {torch_dtype}')

if quant_info is not None:
quant_info.pop('torch_dtype', None)
kwargs.update(quant_info)

if model_info.quant_method is not None:
kwargs['quant_method'] = model_info.quant_method
kwargs['bits'] = model_info.bits
kwargs.update({'model_type': model_type, 'attn_impl': attn_impl})

model_info = MODEL_MAPPING[model_type]
requires = model_info['requires']
for require in requires:
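
To close, a hedged usage sketch of the refactored entry point: with HfConfigFactory.get_model_info supplying model_type, torch_dtype and quantization details from the downloaded config, a plain model id should be enough. The model id below is a placeholder, and the exact return values depend on the checkpoint.

from swift.llm import get_model_tokenizer

model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct')
print(model.dtype)                           # picked via get_default_torch_dtype when torch_dtype is not passed
print(getattr(model, 'quant_method', None))  # set from model_info for quantized checkpoints
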
(The remaining changed files are not rendered in this view.)
