diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py index c5b08daf1..46e0c6892 100644 --- a/swift/llm/__init__.py +++ b/swift/llm/__init__.py @@ -14,7 +14,7 @@ from .argument import (EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments, RLHFArguments, WebuiArguments, AppUIArguments) from .template import TEMPLATE_MAPPING, Template, StopWords, get_template, TemplateType, register_template - from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, get_default_template_type, HfConfigFactory + from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download, HfConfigFactory from .dataset import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor, ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor, TextGenerationPreprocessor, DatasetName, DatasetLoader, HubDatasetLoader, LocalDatasetLoader, @@ -42,7 +42,7 @@ 'RLHFArguments', 'AppUIArguments' ], 'template': ['TEMPLATE_MAPPING', 'Template', 'StopWords', 'get_template', 'TemplateType', 'register_template'], - 'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'get_default_template_type', 'HfConfigFactory'], + 'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'safe_snapshot_download', 'HfConfigFactory'], 'dataset': [ 'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'ConversationsPreprocessor', 'ListPreprocessor', 'PreprocessFunc', 'RenameColumnsPreprocessor', 'SmartPreprocessor', diff --git a/swift/llm/argument/data_args.py b/swift/llm/argument/data_args.py index 27e37918b..8128d29ac 100644 --- a/swift/llm/argument/data_args.py +++ b/swift/llm/argument/data_args.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import List, Literal, Optional, Union -from swift.llm import DATASET_MAPPING, TEMPLATE_MAPPING, get_default_template_type, register_dataset_info_file +from swift.llm import DATASET_MAPPING, TEMPLATE_MAPPING, register_dataset_info_file from swift.utils import get_logger logger = get_logger() @@ -70,7 +70,8 @@ def select_template(self: Union['SftArguments', 'InferArguments']): # [TODO:] """If setting template to `auto`, find a proper one""" if self.template_type == 'auto': - self.template_type = get_default_template_type(self.model_type) + from swift.llm.model.register import ModelInfoReader + self.template_type = ModelInfoReader.get_default_template_type(self.model_type) logger.info(f'Setting template_type: {self.template_type}') def __post_init__(self): diff --git a/swift/llm/argument/model_args.py b/swift/llm/argument/model_args.py index d1a01ff1f..443691611 100644 --- a/swift/llm/argument/model_args.py +++ b/swift/llm/argument/model_args.py @@ -4,16 +4,17 @@ from typing import List, Literal, Optional, Union import torch +from transformers import AutoConfig from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available from transformers.utils.versions import require_version -from swift.llm import MODEL_KEYS_MAPPING, MODEL_MAPPING, ModelType, RLHFArguments -from swift.llm.model import fix_do_sample_warning +from swift.llm import MODEL_KEYS_MAPPING, MODEL_MAPPING +from swift.llm.model import fix_do_sample_warning, get_default_torch_dtype from swift.utils import get_dist_setting, get_logger, use_hf_hub logger = get_logger() -dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32', None: 'auto'} +dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'} 
dtype_mapping_reversed = {v: k for k, v in dtype_mapping.items()} @@ -67,7 +68,6 @@ def select_bnb(self) -> None: self.bnb_4bit_comp_dtype = self.dtype bnb_4bit_compute_dtype = dtype_mapping_reversed[self.bnb_4bit_comp_dtype] - assert bnb_4bit_compute_dtype in {torch.float16, torch.bfloat16, torch.float32} load_in_4bit, load_in_8bit = False, False # default value if self.quant_method == 'bnb': @@ -90,10 +90,10 @@ def __post_init__(self: Union['SftArguments', 'InferArguments']): @dataclass class ModelArguments: - # You can specify the model by either using the model_type or model_id_or_path. + # You can specify the model by model. + model: Optional[str] = None model_type: Optional[str] = field( default=None, metadata={'help': f'model_type choices: {list(MODEL_MAPPING.keys())}'}) - model_id_or_path: Optional[str] = None model_revision: Optional[str] = None dtype: Literal['bf16', 'fp16', 'fp32', 'auto'] = 'auto' @@ -111,7 +111,7 @@ class ModelArguments: def prepare_model_extra_args(self: 'SftArguments'): """Prepare model kwargs and set them to the env""" - self.parse_to_dict(self, 'model_kwargs') + self.parse_to_dict('model_kwargs') for k, v in self.model_kwargs.items(): k = k.upper() os.environ[k] = str(v) @@ -126,22 +126,12 @@ def prepare_device_map_args(self: 'SftArguments'): if isinstance(v, int): self.device_map_config[k] += local_rank - def get_model_group(self): - """Find the model group. This is used to find the model structure""" - model_type = (self.model_type or self.model_id_or_path).replace('-', '_') - model_group = None - for key in MODEL_KEYS_MAPPING.keys(): - if key in model_type.lower(): - model_group = key - break - return model_group - def check_flash_attn(self) -> None: """Some models do not support flash-attention""" model_info = MODEL_MAPPING.get(self.model_type, {}) support_flash_attn = model_info.get('support_flash_attn', False) - if self.use_flash_attn and not support_flash_attn: - logger.warning(f'use_flash_attn: {self.use_flash_attn}, ' f'but support_flash_attn: {support_flash_attn}') + if 'flash' in self.attn_impl and not support_flash_attn: + logger.warning(f'attn_impl: {self.attn_impl}, ' f'but support_flash_attn: {support_flash_attn}') @property def is_multimodal(self) -> bool: @@ -153,7 +143,7 @@ def is_multimodal(self) -> bool: return 'multi-modal' in tags def select_dtype(self) -> None: - """If dtype is `auto`, find a proper dtype by the sft_type/GPU""" + """If dtype is `auto`, find a proper dtype by the train_type/GPU""" # Compatible with --fp16/--bf16 from .train_args import SftArguments if isinstance(self, SftArguments): @@ -165,26 +155,13 @@ def select_dtype(self) -> None: break # handle dtype == 'auto' if self.dtype == 'auto': - if is_torch_cuda_available() or is_torch_npu_available(): - if is_torch_bf16_gpu_available(): - model_torch_dtype = MODEL_MAPPING[self.model_type].get('torch_dtype') - if model_torch_dtype is not None: - self.dtype = dtype_mapping[model_torch_dtype] - elif isinstance(self, SftArguments): - self.dtype = 'bf16' - # else: Keep 'auto'. According to the model's config.json file, - # this behavior is executed in the get_model_tokenizer function. - # This situation will only occur during inference. 
- else: - self.dtype = 'fp16' - else: - # cpu - self.dtype = 'fp32' + torch_dtype = get_default_torch_dtype(self.model_info.torch_dtype) + self.dtype = dtype_mapping[torch_dtype] logger.info(f'Setting args.dtype: {self.dtype}') # Check the validity of dtype if is_torch_cuda_available() or is_torch_npu_available(): if self.dtype == 'fp16': - if isinstance(self, SftArguments) and self.sft_type == 'full': + if isinstance(self, SftArguments) and self.train_type == 'full': self.dtype = 'fp32' logger.warning( 'Fine-tuning with full parameters does not support fp16, and is prone to NaN. ' @@ -208,51 +185,22 @@ def select_dtype(self) -> None: else: raise ValueError(f'args.dtype: {self.dtype}') - def select_model_type(self) -> None: - """model_type may be None, find the right one by `model_id_or_path`""" - from swift.llm.argument import InferArguments - if self.model_id_or_path is not None: - model_mapping_reversed = {} - for k, v in MODEL_MAPPING.items(): - if use_hf_hub(): - model_id = v.get('hf_model_id') - else: - model_id = v.get('model_id_or_path') - if model_id is None: - continue - model_id = model_id.lower() - model_mapping_reversed[model_id] = k - model_id_or_path = self.model_id_or_path - model_id_or_path_lower = model_id_or_path.lower() - - if self.model_type is None and model_id_or_path_lower in model_mapping_reversed: - model_type = model_mapping_reversed[model_id_or_path_lower] - assert self.model_type is None or self.model_type == model_type - self.model_type = model_type - logger.info(f'Setting args.model_type: {model_type}') - else: - if (isinstance(self, InferArguments) and 'checkpoint-' in model_id_or_path - and 'merged' not in model_id_or_path and self.ckpt_dir is None): - raise ValueError('Please use `--ckpt_dir vx-xxx/checkpoint-xxx` to use the checkpoint.') - - model_info = MODEL_MAPPING.get(self.model_type, {}) - if self.model_revision is not None: - model_info['revision'] = self.model_revision - logger.info(f"Setting model_info['revision']: {self.model_revision}") - elif use_hf_hub(): - model_info['revision'] = 'main' - self.model_revision = model_info['revision'] - if self.model_id_or_path is None: - self.model_id_or_path = model_info['hf_model_id'] if use_hf_hub() else model_info['model_id_or_path'] - requires = model_info.get('requires', []) - for require in requires: - require_version(require) + def prepare_model_info(self) -> None: + from swift.llm import safe_snapshot_download, HfConfigFactory + model_dir = safe_snapshot_download(self.model, revision=self.model_revision, load_model=False) + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + model_info = HfConfigFactory.get_model_info(model_config, model_dir) + self.model_info = model_info + if self.model_type is None: + self.model_type = model_info.model_type + if self.quant_method is None: + self.quant_method = model_info.quant_method def __post_init__(self: Union['SftArguments', 'InferArguments']): if self.rope_scaling: logger.info(f'rope_scaling is set to {self.rope_scaling}, please remember to set max_length') self.prepare_model_extra_args() self.prepare_device_map_args() + self.prepare_model_info() self.check_flash_attn() self.select_dtype() - self.select_model_type() diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index e324f0cff..4b3b4ec4b 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -317,13 +317,13 @@ def prepare_train_type(self): if self.is_adapter: assert self.freeze_parameters_ratio == 0. 
and len( self.additional_trainable_parameters) == 0, ('lora does not support, please use `--train_type full`') - if self.is_quant_model(): + if self.quant_method is not None: assert self.quantization_bit == 0, ( f'{self.model_type} is already a quantized model and does not need to be quantized again.') elif self.train_type == 'full': if self.freeze_vit: - if self.get_model_group(): - vision_tower = MODEL_KEYS_MAPPING[self.get_model_group()].vision_tower + if self.model_type in MODEL_KEYS_MAPPING: + vision_tower = MODEL_KEYS_MAPPING[self.model_type].vision_tower if vision_tower: self.freeze_parameters += vision_tower assert 0 <= self.freeze_parameters_ratio <= 1 diff --git a/swift/llm/eval/eval.py b/swift/llm/eval/eval.py index 2673678a0..1b8904171 100644 --- a/swift/llm/eval/eval.py +++ b/swift/llm/eval/eval.py @@ -290,8 +290,8 @@ def eval_opencompass(args: EvalArguments) -> List[Dict[str, Any]]: port = args.port # health check: try to get model_type until raises get_model_type(port, args.deploy_timeout) - model_type = 'default-lora' if args.sft_type in ('lora', - 'longlora') and not args.merge_lora else args.model_type + model_type = 'default-lora' if args.train_type in ('lora', + 'longlora') and not args.merge_lora else args.model_type from swift.llm.infer.deploy import is_generation_template if is_generation_template(args.template_type): url = f'http://127.0.0.1:{port}/v1/completions' diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py index fa935be8d..8e98521ae 100644 --- a/swift/llm/export/export.py +++ b/swift/llm/export/export.py @@ -184,7 +184,7 @@ def llm_export(args: ExportArguments) -> None: logger.info(f'args: {args}') seed_everything(args.seed) if args.to_peft_format: - assert args.sft_type == 'lora', f'args.sft_type: {args.sft_type}' + assert args.train_type == 'lora', f'args.train_type: {args.train_type}' args.ckpt_dir = swift_to_peft_format(args.ckpt_dir) if args.merge_lora: @@ -248,7 +248,7 @@ def llm_export(args: ExportArguments) -> None: assert args.quant_output_dir is not None _args = args assert args.quantization_bit == 0, f'args.quantization_bit: {args.quantization_bit}' - assert args.sft_type == 'full', 'you need to merge lora' + assert args.train_type == 'full', 'you need to merge lora' if args.quant_method == 'awq': from awq import AutoAWQForCausalLM model, template = prepare_model_template( diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py index 82115081f..5bf6ba6dd 100644 --- a/swift/llm/infer/infer.py +++ b/swift/llm/infer/infer.py @@ -90,7 +90,7 @@ def merge_lora(args: InferArguments, **kwargs) -> Optional[str]: logger.info(f'replace_if_exists: {replace_if_exists}') assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.' - assert args.sft_type in args.adapters_can_be_merged, 'Only supports lora & llamapro series models' + assert args.train_type in args.adapters_can_be_merged, 'Only supports lora & llamapro series models' assert not args.is_quant_model(), f'{args.model_type} is a quantized model and does not support merge-lora.' 
if args.quantization_bit != 0: logger.warning('It is not recommended to merge quantized models, ' @@ -141,9 +141,9 @@ def merge_lora(args: InferArguments, if tempdir: shutil.rmtree(tempdir, ignore_errors=True) - logger.info("Setting args.sft_type: 'full'") + logger.info("Setting args.train_type: 'full'") logger.info(f'Setting args.ckpt_dir: {merged_lora_path}') - args.sft_type = 'full' + args.train_type = 'full' args.ckpt_dir = merged_lora_path return merged_lora_path diff --git a/swift/llm/infer/lmdeploy.py b/swift/llm/infer/lmdeploy.py index d3f6f2223..9a1af1296 100644 --- a/swift/llm/infer/lmdeploy.py +++ b/swift/llm/infer/lmdeploy.py @@ -77,10 +77,10 @@ def __init__(self, args: InferArguments, use_async: bool = False, **kwargs): logger.info(f'device_count: {torch.cuda.device_count()}') assert args.quantization_bit == 0, 'not support bnb' - assert not args.sft_type == 'lora', 'you need to merge lora' + assert not args.train_type == 'lora', 'you need to merge lora' # Loading Model and Tokenizer model_id_or_path = None - if args.sft_type == 'full' and args.ckpt_dir is not None: + if args.train_type == 'full' and args.ckpt_dir is not None: model_id_or_path = args.ckpt_dir elif args.model_id_or_path is not None: model_id_or_path = args.model_id_or_path diff --git a/swift/llm/infer/transformers.py b/swift/llm/infer/transformers.py index 7feddfdeb..2054565a8 100644 --- a/swift/llm/infer/transformers.py +++ b/swift/llm/infer/transformers.py @@ -115,7 +115,7 @@ def prepare_model_template_hf(args: InferArguments, use_async: bool = False, **k if args.use_flash_attn is not None: kwargs['use_flash_attn'] = args.use_flash_attn model_id_or_path = None - if args.sft_type == 'full' and args.ckpt_dir is not None: + if args.train_type == 'full' and args.ckpt_dir is not None: model_id_or_path = args.ckpt_dir elif args.model_id_or_path is not None: model_id_or_path = args.model_id_or_path diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index 69a0aec68..35f123799 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -11,7 +11,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase) from transformers.integrations import is_deepspeed_zero3_enabled -from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available +from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available from transformers.utils.versions import require_version from swift.llm import TemplateType @@ -165,7 +165,7 @@ def get_model_tokenizer_from_local(model_dir: str, logger.info(f'model_kwargs: {model_kwargs}') model = automodel_class.from_pretrained( model_dir, config=model_config, torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs) - model.quant_method = kwargs.get('quant_method') + model.quant_method = kwargs.get('quant_method') # TODO: check bnb model.quant_bits = kwargs.get('bits') model.is_training = kwargs.get('is_training', False) max_model_len = HfConfigFactory.get_max_model_len(model_config) @@ -261,15 +261,20 @@ def get_default_device_map(): return 'auto' -def get_default_torch_dtype(model_config: PretrainedConfig, quant_info: Optional[Dict[str, Any]] = None) -> torch.dtype: - torch_dtype = HfConfigFactory.get_torch_dtype(model_config) - if torch_dtype is None: - if quant_info is None: - quant_info = HfConfigFactory.get_quant_info(model_config) - torch_dtype = quant_info.get('torch_dtype') - if torch_dtype in 
{torch.float32, None}: - torch_dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16 - return torch_dtype +def get_default_torch_dtype(torch_dtype: torch.dtype): + # torch_dtype: torch_dtype in config.json + if is_torch_cuda_available() or is_torch_npu_available(): + if is_torch_bf16_gpu_available(): + if torch_dtype in {torch.float16, torch.bfloat16}: + res = torch_dtype + else: + res = torch.bfloat16 + else: + res = torch.float16 + else: + # cpu + res = torch.float32 + return res def get_model_tokenizer(model_id_or_path: str, @@ -309,23 +314,18 @@ def get_model_tokenizer(model_id_or_path: str, model_kwargs['device_map'] = get_default_device_map() model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) kwargs['model_config'] = model_config + model_info = HfConfigFactory.get_model_info(model_config, model_dir) if model_type is None: - model_types = HfConfigFactory.get_matched_model_types(model_config, model_dir) - if len(model_types) > 1: - raise ValueError('Unable to obtain the accurate model_type based on the model architecture. ' - f'Please explicitly provide the model_type. Available model_types: {model_types}') - model_type = model_types[0] + model_type = model_info.model_type logger.info(f'Setting model_type: {model_type}') - quant_info = HfConfigFactory.get_quant_info(model_config) if torch_dtype is None: - torch_dtype = get_default_torch_dtype(model_config, quant_info) + torch_dtype = get_default_torch_dtype(model_info.torch_dtype) logger.info(f'Setting torch_dtype: {torch_dtype}') - - if quant_info is not None: - quant_info.pop('torch_dtype', None) - kwargs.update(quant_info) - + if model_info.quant_method is not None: + kwargs['quant_method'] = model_info.quant_method + kwargs['bits'] = model_info.bits kwargs.update({'model_type': model_type, 'attn_impl': attn_impl}) + model_info = MODEL_MAPPING[model_type] requires = model_info['requires'] for require in requires: diff --git a/swift/llm/model/utils.py b/swift/llm/model/utils.py index cc188b8d5..2644065d7 100644 --- a/swift/llm/model/utils.py +++ b/swift/llm/model/utils.py @@ -1,7 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import hashlib import os -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union import torch import torch.distributed as dist @@ -18,9 +19,46 @@ logger = get_logger() +@dataclass +class ModelInfo: + model_type: str + torch_dtype: torch.dtype + quant_method: Literal['gptq', 'awq', 'bnb', 'aqlm', None] = None + bits: int = 0 + + class HfConfigFactory: """This class is used to read config from config.json(maybe params.json also)""" + @staticmethod + def _get_torch_dtype(config: PretrainedConfig, quant_info: Dict[str, Any]) -> torch.dtype: + for key in ['torch_dtype', 'params_dtype']: + torch_dtype = HfConfigFactory.get_config_attr(config, key) + if torch_dtype is not None: + break + torch_dtype = HfConfigFactory._to_torch_dtype(torch_dtype) + if torch_dtype is None: + torch_dtype = quant_info.get('torch_dtype') + return torch_dtype + + @staticmethod + def get_model_info(config: PretrainedConfig, model_dir: str, model_type: Optional[str] = None) -> ModelInfo: + if model_type is None: + model_types = HfConfigFactory._get_matched_model_types(config, model_dir) + if len(model_types) > 1: + raise ValueError('Unable to obtain the accurate model_type based on the model architecture. ' + f'Please explicitly provide the model_type. 
Available model_types: {model_types}') + model_type = model_types[0] + + quant_info = HfConfigFactory._get_quant_info(config) + torch_dtype = HfConfigFactory._get_torch_dtype(config, quant_info) + res = ModelInfo(model_type, torch_dtype) + if quant_info is not None: + res.quant_method = quant_info['quant_method'] + res.bits = quant_info['bits'] + + return res + @staticmethod def _get_config_attrs(config: PretrainedConfig, attr_name: str) -> List[Tuple[PretrainedConfig, Any]]: res = [] @@ -50,14 +88,6 @@ def set_config_attr(config, attr_name: str, value: Any) -> None: for config, _ in attrs: setattr(config, attr_name, value) - @staticmethod - def get_torch_dtype(config) -> Optional[torch.dtype]: - for key in ['torch_dtype', 'params_dtype']: - torch_dtype = HfConfigFactory.get_config_attr(config, key) - if torch_dtype is None: - continue - return HfConfigFactory._to_torch_dtype(torch_dtype) - @staticmethod def get_max_model_len(config: PretrainedConfig) -> Optional[int]: """Get the max length supported by the model""" @@ -96,7 +126,7 @@ def _to_torch_dtype(torch_dtype: Union[str, torch.dtype]) -> torch.dtype: return torch_dtype @staticmethod - def get_quant_info(config: PretrainedConfig) -> Optional[Dict[str, Any]]: + def _get_quant_info(config: PretrainedConfig) -> Optional[Dict[str, Any]]: """Get quant_method, quant_bits, dtype. not support hqq/eetq now, support awq/gptq/bnb/aqlm""" quantization_config = getattr(config, 'quantization_config', None) if quantization_config is None: @@ -111,7 +141,7 @@ def get_quant_info(config: PretrainedConfig) -> Optional[Dict[str, Any]]: if bits is not None: res['bits'] = bits elif quant_method == 'bitsandbytes': - res['quant_method'] = quant_method + res['quant_method'] = 'bnb' load_in_4bit = quantization_config.get('load_in_4bit') load_in_8bit = quantization_config.get('load_in_8bit') bnb_4bit_compute_dtype = quantization_config.get('bnb_4bit_compute_dtype') @@ -123,7 +153,7 @@ def get_quant_info(config: PretrainedConfig) -> Optional[Dict[str, Any]]: return res or None @staticmethod - def get_matched_model_types(config: PretrainedConfig, model_dir: Optional[str] = None) -> List[str]: + def _get_matched_model_types(config: PretrainedConfig, model_dir: Optional[str] = None) -> List[str]: """Get possible model_type.""" # get possible model_types based on the model architecture. 
from .register import get_arch_mapping diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 6599daa99..19ca2e66c 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -5,7 +5,6 @@ from swift.utils import get_logger, get_main, seed_everything from ..argument import RLHFArguments from ..template import TEMPLATE_MAPPING -from .patcher import TrainTemplate from .sft import prepare_dataset, prepare_train_model_template, trainer_train logger = get_logger() @@ -20,9 +19,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]: logger.warning(f"Please check if args.template_type: '{args.template_type}' is correct.") msg = {} - model, ref_model, template, callbacks, optimizers = prepare_train_model_template(args) + model, ref_model, template, callbacks, optimizer_callback = prepare_train_model_template(args) with TrainerFactory.patch_template(args, template): - template = TrainTemplate(template) train_dataset, val_dataset = prepare_dataset(args, template, msg) return trainer_train( @@ -32,7 +30,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]: train_dataset, val_dataset, callbacks=callbacks, - optimizers=optimizers, + optimizers=optimizer_callback(model, train_dataset, args), msg=msg, ref_model=ref_model) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 3c27eaf60..e4f59e02c 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -13,26 +13,23 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils import is_torch_npu_available, strtobool -from swift.llm import HfConfigFactory +from swift.llm import TEMPLATE_MAPPING, HfConfigFactory, get_model_tokenizer, get_template from swift.llm.argument import PtArguments, RLHFArguments, SftArguments -from swift.torchacc_utils import patch_acc_model from swift.trainers import TrainerFactory from swift.trainers.utils import can_return_loss, find_labels from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_logger, get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images, preprocess_logits_for_metrics, seed_everything, show_layers, use_torchacc) +from ...utils.torchacc_utils import patch_acc_model from ...utils.utils import get_time_info from ..dataset.loader import DatasetLoader from ..dataset.utils import (ConstantLengthDataset, LazyLLMDataset, dataset_map, print_example, sort_by_max_length, stat_dataset) -from ..model.model import get_model_tokenizer from ..template import Template -from ..template.base import get_template -from ..template.template import TEMPLATE_MAPPING from ..tuner import prepare_modules -from ..utils import set_generation_config +from ..utils.utils import set_generation_config from .accelerator import ta_accelerate -from .patcher import TrainTemplate, patch_ddp_mp, training_context +from .patcher import patch_ddp_mp logger = get_logger() @@ -83,9 +80,8 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]: args.model_type, model_id_or_path=args.model_id_or_path, revision=args.model_revision, load_model=False) # Loading Dataset - template: Template = get_template(args.template_type, tokenizer, args.system, args.max_length, - args.truncation_strategy) - template = TrainTemplate(template) + template: Template = get_template( + args.template_type, tokenizer, args.system, args.max_length, truncation_strategy=args.truncation_strategy) train_dataset, val_dataset = _get_train_val_dataset(args) td0, tkwargs0 = template.encode(train_dataset[0]) @@ -250,7 +246,7 @@ def 
prepare_train_model_template(args, msg: Optional[Dict[str, Any]] = None): model.label_names = label_names model.return_loss = return_loss - model, callbacks, optimizers = prepare_modules(model, args) + model, callbacks, optimizer_callback = prepare_modules(model, args) show_layers(model) logger.info(model) @@ -274,7 +270,7 @@ def prepare_train_model_template(args, msg: Optional[Dict[str, Any]] = None): args.bf16, args.fp16, gradient_checkpointing=True, - fsdp_flatten_parameters=(args.sft_type == 'full')) + fsdp_flatten_parameters=(args.train_type == 'full')) template_kwargs = {'loss_scale': args.loss_scale, 'tools_prompt': args.tools_prompt} if args.sequence_parallel_size and args.sequence_parallel_size > 1: @@ -285,8 +281,7 @@ def prepare_train_model_template(args, msg: Optional[Dict[str, Any]] = None): tokenizer, args.system, args.max_length, - args.truncation_strategy, - model=model, + truncation_strategy=args.truncation_strategy, **template_kwargs) template._is_training = True template.encode = partial( @@ -296,11 +291,11 @@ def prepare_train_model_template(args, msg: Optional[Dict[str, Any]] = None): logger.info(f'args.lazy_tokenize: {args.lazy_tokenize}') if not isinstance(args, RLHFArguments): - return model, template, callbacks, optimizers + return model, template, callbacks, optimizer_callback # ref_model ref_model = None - if not args.ref_model_free and (args.ref_model_type or args.sft_type == 'full'): + if not args.ref_model_free and (args.ref_model_type or args.train_type == 'full'): if args.ref_model_type: kwargs['model_id_or_path'] = args.ref_model_id_or_path kwargs['revision'] = args.ref_model_revision @@ -318,7 +313,7 @@ def prepare_train_model_template(args, msg: Optional[Dict[str, Any]] = None): ref_model.requires_grad_(False).eval() template.ref_model = ref_model - return model, ref_model, template, callbacks, optimizers + return model, ref_model, template, callbacks, optimizer_callback def prepare_dataset(args, template: Template, msg: Optional[Dict[str, Any]] = None): @@ -398,7 +393,7 @@ def trainer_train(args, if msg is None: msg = {} training_args = args.training_args - padding_to = args.max_length if args.sft_type == 'longlora' else None + padding_to = args.max_length if args.train_type == 'longlora' else None tokenizer = template.tokenizer data_collator = partial(template.data_collator, padding_to=padding_to) @@ -453,8 +448,8 @@ def trainer_train(args, json.dump(check_json_format(args_obj.__dict__), f, ensure_ascii=False, indent=2) logging_path = os.path.join(args.output_dir, 'logging.jsonl') logger.info(f'The logging file will be saved in: {logging_path}') - with training_context([model] if ref_model is None else [model, ref_model], - [template] if ref_model is None else [template, template]): + with template.training_context([model] if ref_model is None else [model, ref_model], + [template] if ref_model is None else [template, template]): trainer.train(training_args.resume_from_checkpoint) last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None) logger.info(f'last_model_checkpoint: {last_model_checkpoint}') @@ -503,11 +498,17 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]: if args.train_backend == 'megatron': return llm_sft_megatron(args) msg = {} - model, template, callbacks, optimizers = prepare_train_model_template(args, msg) - template = TrainTemplate(template) + model, template, callbacks, optimizer_callback = prepare_train_model_template(args, msg) train_dataset, val_dataset = prepare_dataset(args, template, msg) return 
trainer_train( - args, model, template, train_dataset, val_dataset, callbacks=callbacks, optimizers=optimizers, msg=msg) + args, + model, + template, + train_dataset, + val_dataset, + callbacks=callbacks, + optimizers=optimizer_callback(model, train_dataset, args), + msg=msg) def get_sft_main(args, llm): diff --git a/swift/llm/tuner.py b/swift/llm/tuner.py index 2ecf37442..114e14f38 100644 --- a/swift/llm/tuner.py +++ b/swift/llm/tuner.py @@ -4,6 +4,7 @@ from typing import List import json +import numpy as np import torch import transformers from packaging import version @@ -29,7 +30,7 @@ def handle_target_modules(model, args: SftArguments) -> None: args: The SftArguments """ - if args.sft_type == 'ia3': + if args.train_type == 'ia3': assert len(args.ia3_feedforward_modules) > 0, ('Setting ia3_target_modules to `ALL` ' 'need to pass MLP linear names to `ia3_feedforward_modules`') target_modules: List[str] = args.target_modules @@ -93,7 +94,7 @@ def on_train_begin(self, _args, state, control, **kwargs): model = kwargs['model'] if hasattr(model, 'set_active_adapters'): model.set_active_adapters(model.adapters.keys(), offload='cpu') - if self.args.sft_type == 'adalora': + if self.args.train_type == 'adalora': model.peft_config['default'].total_step = state.max_steps def zero_grad(_self, *args, **kwargs): @@ -104,7 +105,7 @@ def zero_grad(_self, *args, **kwargs): model.zero_grad = types.MethodType(zero_grad, model) def on_step_end(self, _args, state, control, **kwargs): - if self.args.sft_type == 'adalora': + if self.args.train_type == 'adalora': self.global_step = state.global_step @@ -161,7 +162,7 @@ def prepare_modules(model, args: SftArguments): group_type = args.get_model_group() # Preparing LoRA - if args.is_adapter(): + if args.is_adapter: if args.resume_from_checkpoint is None: handle_target_modules(model, args) if args.init_lora_weights and args.init_lora_weights.lower() in ('true', 'false'): @@ -182,7 +183,7 @@ def prepare_modules(model, args: SftArguments): 'init_lora_weights': args.init_lora_weights, } - if args.sft_type in ('lora', 'longlora'): + if args.train_type in ('lora', 'longlora'): # Fix the name of the layer in xcomposer that contains Plora. 
if any(['lora_' in n for n, p in model.named_parameters()]): model.requires_grad_(False) @@ -198,7 +199,7 @@ def prepare_modules(model, args: SftArguments): logger.info(f'lora_config: {lora_config}') elif args.tuner_backend == 'unsloth': from unsloth import FastLanguageModel - assert args.sft_type == 'lora', 'Unsloth does not support LongLoRA' + assert args.train_type == 'lora', 'Unsloth does not support LongLoRA' lora_kwargs.pop('lorap_lr_ratio') model = FastLanguageModel.get_peft_model( model, @@ -207,13 +208,13 @@ def prepare_modules(model, args: SftArguments): **lora_kwargs, ) logger.info(f'unsloth_config: {lora_kwargs}') - if args.sft_type == 'longlora': + if args.train_type == 'longlora': assert LongLoRAModelType.LLAMA in args.model_type assert version.parse(transformers.__version__) >= version.parse('4.39.3') from swift.tuners.longlora.llama import replace_llama_attn replace_llama_attn(model) model.config.group_size_ratio = 0.25 - elif args.sft_type == 'adalora': + elif args.train_type == 'adalora': lora_kwargs.pop('lorap_lr_ratio', None) lora_kwargs['rank_pattern'] = None adalora_config = AdaLoraConfig( @@ -230,7 +231,7 @@ def prepare_modules(model, args: SftArguments): ) model = Swift.prepare_model(model, adalora_config) logger.info(f'adalora_config: {adalora_config}') - elif args.sft_type == 'ia3': + elif args.train_type == 'ia3': ia3_config = IA3Config( task_type='CAUSAL_LM', target_modules=args.target_modules, @@ -239,14 +240,14 @@ def prepare_modules(model, args: SftArguments): ) model = Swift.prepare_model(model, ia3_config) logger.info(f'ia3_config: {ia3_config}') - elif args.sft_type == 'llamapro': + elif args.train_type == 'llamapro': llamapro_config = LLaMAProConfig( model_type=group_type, num_new_blocks=args.llamapro_num_new_blocks, num_groups=args.llamapro_num_groups) model = Swift.prepare_model(model, llamapro_config) logger.info(f'llamapro_config: {llamapro_config}') - elif args.sft_type == 'adapter': + elif args.train_type == 'adapter': assert group_type in MODEL_KEYS_MAPPING mlp_key = MODEL_KEYS_MAPPING[group_type].mlp mlp_key = mlp_key.split('.{}.')[1] @@ -258,7 +259,7 @@ def prepare_modules(model, args: SftArguments): act_layer=args.adapter_act) model = Swift.prepare_model(model, adapter_config) logger.info(f'adapter_config: {adapter_config}') - elif args.sft_type == 'vera': + elif args.train_type == 'vera': vera_config = VeraConfig( r=args.vera_rank, target_modules=args.target_modules, @@ -270,7 +271,7 @@ def prepare_modules(model, args: SftArguments): vera_config = handle_vera_target_modules(model, vera_config) model = Swift.prepare_model(model, vera_config) logger.info(f'vera_config: {vera_config}') - elif args.sft_type == 'boft': + elif args.train_type == 'boft': boft_config = BOFTConfig( boft_block_size=args.boft_block_size, boft_block_num=args.boft_block_num, @@ -281,7 +282,7 @@ def prepare_modules(model, args: SftArguments): ) model = Swift.prepare_model(model, boft_config) logger.info(f'boft_config: {boft_config}') - elif args.sft_type == 'fourierft': + elif args.train_type == 'fourierft': from peft import FourierFTConfig fourier_config = FourierFTConfig( target_modules=args.target_modules, @@ -291,7 +292,7 @@ def prepare_modules(model, args: SftArguments): ) model = Swift.prepare_model(model, fourier_config) logger.info(f'fourier_config: {fourier_config}') - elif args.sft_type == 'reft': + elif args.train_type == 'reft': reft_config = ReftConfig( model_type=group_type, layer_key=args.reft_layer_key, @@ -324,11 +325,11 @@ def prepare_modules(model, args: 
SftArguments): logger.info('Convert trainable parameters from fp16 to fp32.') is_logging = True p.data = p.data.to(dtype=torch.float32) - elif args.sft_type in extra_tuners: - tuner: Tuner = extra_tuners[args.sft_type] + elif args.train_type in extra_tuners: + tuner: Tuner = extra_tuners[args.train_type] model = tuner.prepare_model(model, args) model.is_tuner_plugin = True - elif args.sft_type == 'full': + elif args.train_type == 'full': model.train() model.requires_grad_(True) @@ -362,7 +363,7 @@ def prepare_modules(model, args: SftArguments): logger.warning( f'There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.') else: - raise ValueError(f'args.sft_type: {args.sft_type}') + raise ValueError(f'args.train_type: {args.train_type}') if args.sequence_parallel_size > 1: from swift.trainers.xtuner import dispatch_module_xtuner @@ -372,31 +373,15 @@ def prepare_modules(model, args: SftArguments): if args.lora_lr_ratio: optimizer_callback = optimizers_map['lorap'] if args.use_galore: - from swift.trainers.optimizers.galore import GaLoreConfig if args.galore_target_modules is None: args.galore_target_modules = find_all_linears(model, 0, args.model_type, args.quant_method) if args.galore_with_embedding: args.galore_target_modules += find_embedding(model) - args.training_args.galore_config = GaLoreConfig( - target_modules=args.galore_target_modules, - rank=args.galore_rank, - update_proj_gap=args.galore_update_proj_gap, - galore_scale=args.galore_scale, - proj_type=args.galore_proj_type, - optim_per_parameter=args.galore_optim_per_parameter, - quantize=args.galore_quantization, - proj_quant=args.galore_proj_quant, - proj_bits=args.galore_proj_bits, - proj_group_size=args.galore_proj_group_size, - cos_threshold=args.galore_cos_threshold, - gamma_proj=args.galore_gamma_proj, - queue_size=args.galore_queue_size, - ) optimizer_callback = optimizers_map['galore'] callbacks = [] if args.lisa_activated_layers > 0: - assert args.sft_type == 'full', 'LISA only supports full parameter training.' + assert args.train_type == 'full', 'LISA only supports full parameter training.' 
lisa_callback = DynamicLayerActivationCallback( n_layers=args.lisa_activated_layers, # Number of layers to activate step_interval=args.lisa_step_interval, # Step interval to update active layers @@ -404,7 +389,7 @@ def prepare_modules(model, args: SftArguments): lisa_callback.switch_active_layers() # Make trainable parameters printing a correct value callbacks.append(lisa_callback) - if args.is_adapter() and args.tuner_backend == 'swift': + if args.is_adapter and args.tuner_backend == 'swift': callbacks.append(TrainerAdapterCallback(args)) callbacks.extend(extra_callbacks or []) - return model, callbacks, optimizer_callback(model, args) + return model, callbacks, optimizer_callback diff --git a/swift/plugin/optimizer.py b/swift/plugin/optimizer.py index da3168166..d2b1e7660 100644 --- a/swift/plugin/optimizer.py +++ b/swift/plugin/optimizer.py @@ -1,11 +1,27 @@ +import math + from transformers import Trainer -from swift.llm.utils import calculate_max_steps from swift.trainers.optimizers.galore import create_optimizer_and_scheduler +from swift.utils import get_dist_setting + + +def calculate_max_steps(dataset, args: 'SftArguments') -> int: + if args.max_steps: + max_steps = args.max_steps + else: + assert not args.streaming + len_dataset = len(dataset) + _, _, world_size, _ = get_dist_setting() + total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size + num_update_steps_per_epoch = len_dataset // total_train_batch_size + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + return max_steps -def create_galore_optimizers(model, args): - training_steps = calculate_max_steps(args) +def create_galore_optimizers(model, dataset, args): + training_steps = calculate_max_steps(dataset, args) return create_optimizer_and_scheduler( model, args.training_args, @@ -15,8 +31,7 @@ def create_galore_optimizers(model, args): weight_decay=args.weight_decay) -def create_lorap_optimizers(model, args): - training_steps = calculate_max_steps(args) +def create_lorap_optimizers(model, dataset, args): args = args.training_args optimizer_grouped_parameters = None if hasattr(model, 'create_optimizer_param_groups'): @@ -41,7 +56,7 @@ def create_lorap_optimizers(model, args): return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None -def default_create_optimizers(model, args): +def default_create_optimizers(model, dataset, args): return None, None diff --git a/swift/plugin/tuner.py b/swift/plugin/tuner.py index 90a54b52d..ba1170d2f 100644 --- a/swift/plugin/tuner.py +++ b/swift/plugin/tuner.py @@ -1,6 +1,6 @@ from peft import IA3Config, PeftModel, get_peft_model -from swift.utils.module_mapping import MODEL_KEYS_MAPPING, ModelKeys +from swift.llm.module_mapping import MODEL_KEYS_MAPPING, ModelKeys from swift.utils.torch_utils import find_all_linears diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index a1921bdbc..76faaab14 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -136,7 +136,7 @@ def _add_adapter_cfg(self, output_dir: str) -> None: if not hasattr(self, 'sft_args'): return sft_args = self.sft_args - if sft_args.sft_type == 'full': + if sft_args.train_type == 'full': return configuration_path = os.path.join(output_dir, 'configuration.json') new_cfg = {} @@ -262,8 +262,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): save_safetensors = self.args.save_safetensors sft_args = getattr(self, 'sft_args', None) - if sft_args 
and sft_args.sft_type in extra_tuners: - tuner: Tuner = extra_tuners[sft_args.sft_type] + if sft_args and sft_args.train_type in extra_tuners: + tuner: Tuner = extra_tuners[sft_args.train_type] tuner.save_pretrained(self.model, output_dir) elif not isinstance(self.model, supported_classes): if state_dict is None: @@ -284,14 +284,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): else: self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors) # tokenizer - if self.tokenizer is not None and sft_args is not None and sft_args.sft_type == 'full': + if self.tokenizer is not None and sft_args is not None and sft_args.train_type == 'full': if hasattr(self.tokenizer, 'processor'): self.tokenizer.processor.save_pretrained(output_dir) self.tokenizer.save_pretrained(output_dir) # training_args.bin torch.save(self.args, os.path.join(output_dir, 'training_args.bin')) # additional files - if sft_args is not None and sft_args.sft_type == 'full': + if sft_args is not None and sft_args.train_type == 'full': # TODO:additional_saved_files additional_files = getattr(self.args, 'additional_saved_files', None) or [] + ['preprocessor_config.json'] if model_dir is not None: diff --git a/swift/tuners/llamapro.py b/swift/tuners/llamapro.py index 84b6dd452..20397ed33 100644 --- a/swift/tuners/llamapro.py +++ b/swift/tuners/llamapro.py @@ -6,8 +6,8 @@ import torch from torch import nn +from swift.llm.module_mapping import MODEL_KEYS_MAPPING, ModelKeys from swift.utils.logger import get_logger -from swift.utils.module_mapping import MODEL_KEYS_MAPPING, ModelKeys from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput logger = get_logger() diff --git a/swift/tuners/utils.py b/swift/tuners/utils.py index 2353db422..95b30804e 100644 --- a/swift/tuners/utils.py +++ b/swift/tuners/utils.py @@ -20,9 +20,9 @@ from peft.utils import ModulesToSaveWrapper as _ModulesToSaveWrapper from peft.utils import _get_submodules +from swift.llm.module_mapping import MODEL_KEYS_MAPPING, ModelKeys from swift.utils.constants import BIN_EXTENSIONS from swift.utils.logger import get_logger -from swift.utils.module_mapping import MODEL_KEYS_MAPPING, ModelKeys logger = get_logger() diff --git a/swift/ui/llm_eval/llm_eval.py b/swift/ui/llm_eval/llm_eval.py index d73ec5c72..70864117c 100644 --- a/swift/ui/llm_eval/llm_eval.py +++ b/swift/ui/llm_eval/llm_eval.py @@ -187,4 +187,4 @@ def eval_model(cls, *args): run_command, eval_args, log_file = cls.eval(*args) os.system(run_command) time.sleep(2) - return gr.update(open=True), EvalRuntime.refresh_tasks(log_file), [eval_args.sft_type] + return gr.update(open=True), EvalRuntime.refresh_tasks(log_file), [eval_args.train_type] diff --git a/swift/ui/llm_export/llm_export.py b/swift/ui/llm_export/llm_export.py index f594886fa..9056afa38 100644 --- a/swift/ui/llm_export/llm_export.py +++ b/swift/ui/llm_export/llm_export.py @@ -191,4 +191,4 @@ def export_model(cls, *args): run_command, export_args, log_file = cls.export(*args) os.system(run_command) time.sleep(2) - return gr.update(open=True), ExportRuntime.refresh_tasks(log_file), [export_args.sft_type] + return gr.update(open=True), ExportRuntime.refresh_tasks(log_file), [export_args.train_type] diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py index df3f6aa7b..8365e7bca 100644 --- a/swift/ui/llm_infer/llm_infer.py +++ b/swift/ui/llm_infer/llm_infer.py @@ -308,7 +308,7 @@ def deploy_model(cls, *args): gr.Info(cls.locale('load_alert', 
cls.lang)['value']) time.sleep(2) return gr.update(open=True), Runtime.refresh_tasks(log_file), [ - deploy_args.model_type, deploy_args.template_type, deploy_args.sft_type + deploy_args.model_type, deploy_args.template_type, deploy_args.train_type ] @classmethod
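
---

Illustration (reviewer note, not part of the patch): the central change above is that `ModelArguments.prepare_model_info` now resolves `model_type`, `torch_dtype` and quantization info from the checkpoint's `config.json` via `HfConfigFactory.get_model_info`, instead of the removed `select_model_type` reverse lookup over `MODEL_MAPPING`. A minimal sketch of the new flow, using only calls introduced in this patch (`safe_snapshot_download`, `HfConfigFactory.get_model_info`, `get_default_torch_dtype`); the model path is a placeholder:

```python
# Sketch of the post-refactor model-info flow; mirrors ModelArguments.prepare_model_info
# and the dtype handling in select_dtype / get_model_tokenizer above.
from transformers import AutoConfig

from swift.llm import HfConfigFactory, safe_snapshot_download
from swift.llm.model import get_default_torch_dtype

model_dir = safe_snapshot_download('/path/to/model', revision=None, load_model=False)  # placeholder path
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)

# ModelInfo(model_type, torch_dtype, quant_method, bits) as defined in swift/llm/model/utils.py
model_info = HfConfigFactory.get_model_info(model_config, model_dir)

# dtype == 'auto' now maps the config.json dtype through get_default_torch_dtype:
# bf16 on bf16-capable GPUs/NPUs, fp16 on other accelerators, fp32 on CPU.
torch_dtype = get_default_torch_dtype(model_info.torch_dtype)
print(model_info.model_type, model_info.quant_method, model_info.bits, torch_dtype)
```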
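Similarly, the default-template lookup used by `data_args.select_template` now goes through `ModelInfoReader.get_default_template_type` rather than the removed top-level `get_default_template_type` export. A hedged sketch of the call site; `ModelInfoReader` itself lives in `swift/llm/model/register.py` but its body is not shown in this patch, and the model_type below is only a placeholder:

```python
# Assumed usage based on the import added in swift/llm/argument/data_args.py.
from swift.llm.model.register import ModelInfoReader

model_type = 'qwen2-7b-instruct'  # placeholder, not taken from the patch
template_type = ModelInfoReader.get_default_template_type(model_type)
print(f'Setting template_type: {template_type}')
```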
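Finally, the optimizer plugin contract changes: `prepare_modules` now returns the optimizer factory uninvoked, and the factories in `swift/plugin/optimizer.py` take `(model, dataset, args)` so `calculate_max_steps` can be derived from the actual training dataset. Below is a sketch of a custom factory under the new signature; the factory name is hypothetical, and `args.learning_rate` is assumed to exist on `SftArguments` (only `args.weight_decay` appears in the patch):

```python
# Sketch of the new optimizer-factory contract: factories receive
# (model, train_dataset, args) and return (optimizer, lr_scheduler);
# returning (None, None) lets the Trainer build its defaults, as
# default_create_optimizers does in swift/plugin/optimizer.py.
from typing import Any, Optional, Tuple

import torch


def create_my_optimizers(model, dataset, args) -> Tuple[Optional[Any], Optional[Any]]:
    # Hypothetical custom factory, not part of the patch.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if 'bias' in name else decay).append(param)
    optimizer = torch.optim.AdamW(
        [
            {'params': decay, 'weight_decay': args.weight_decay},
            {'params': no_decay, 'weight_decay': 0.0},
        ],
        lr=args.learning_rate,  # assumed SftArguments attribute
    )
    # No custom scheduler: the Trainer falls back to its own.
    return optimizer, None


# In sft.py / rlhf.py the callback is now invoked with the dataset:
#     optimizers=optimizer_callback(model, train_dataset, args)
```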