From 9eb2779556c4c5ce7cc4968638b7398994a19e51 Mon Sep 17 00:00:00 2001 From: "huangjintao.hjt" Date: Mon, 14 Oct 2024 14:08:57 +0800 Subject: [PATCH 1/4] update --- swift/llm/utils/template.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index df559de1a..a3b339a95 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -1549,7 +1549,8 @@ def _process_image_qwen(image): class _Qwen2VLTemplateMixin: - + image_token_id = 151655 + video_token_id = 151656 def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, example: Dict[str, Any]) -> List[Context]: assert media_type in {'image', 'video'} @@ -1598,13 +1599,13 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An for media_type in ['images', 'videos']: if locals()[media_type]: if media_type == 'images': - media_token = 151655 + media_token = self.image_token_id media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt') media_grid_thw = media_inputs['image_grid_thw'] else: media_inputs = processor.image_processor(images=None, videos=videos, return_tensors='pt') media_grid_thw = media_inputs['video_grid_thw'] - media_token = 151656 + media_token = self.video_token_id idx_list = _findall(input_ids, media_token) added_tokens_len = 0 for i, idx in enumerate(idx_list): From 9e78215e4441e6df48a7be50585f1c417e9c8cfc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 14 Oct 2024 14:25:26 +0800 Subject: [PATCH 2/4] update --- swift/llm/utils/template.py | 62 ++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index a3b339a95..c464b50df 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -1596,6 +1596,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An labels = inputs['labels'] images = example.get('images') or [] videos = example.get('videos') or [] + data = {} for media_type in ['images', 'videos']: if locals()[media_type]: if media_type == 'images': @@ -1618,32 +1619,51 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:] added_tokens_len += token_len - 1 - inputs.update(media_inputs) + data.update(media_inputs) - inputs['input_ids'] = input_ids + inputs['input_ids'] =input_ids inputs['labels'] = labels - inputs['_data'] = {'plain_text': not images and not videos, 'input_ids': torch.tensor(input_ids)[None]} + data['input_ids'] = torch.tensor(input_ids)[None] + inputs['_data'] = _data return inputs, {} def _post_encode(self, model, data: Any) -> Dict[str, Any]: - plain_text = data.pop('plain_text', False) - if is_deepspeed_enabled() and plain_text: - from PIL import Image - images = [Image.new('RGB', (32, 32), (0, 0, 0))] - processor = self.tokenizer.processor - media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt') - input_ids = data['input_ids'] - device = input_ids.device - pixel_values = media_inputs['pixel_values'].to(device) - _model = model.model - if not hasattr(_model, 'embed_tokens'): - _model = _model.model # LoRA - inputs_embeds = _model.embed_tokens(input_ids) - pixel_values = pixel_values.type(model.visual.get_dtype()) - image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) - 
inputs_embeds += image_embeds.mean() * 0. - return {'inputs_embeds': inputs_embeds[0]} - return {} + _model = model.model + if not hasattr(_model, 'embed_tokens'): + _model = _model.model # LoRA + input_ids = data['input_ids'] + pixel_values = data.get('pixel_values') + pixel_values_videos = data.get('pixel_values_videos') + inputs_embeds = _model.embed_tokens(input_ids) + if pixel_values is None and pixel_values_videos is None: # plain-text + if is_deepspeed_enabled(): + from PIL import Image + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + processor = self.tokenizer.processor + media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt') + device = input_ids.device + pixel_values = media_inputs['pixel_values'].to(device) + + pixel_values = pixel_values.type(model.visual.get_dtype()) + image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw']) + inputs_embeds += image_embeds.mean() * 0. + else: + if pixel_values is not None: + image_grid_thw = data['image_grid_thw'] + pixel_values = pixel_values.type(_model.visual.get_dtype()) + image_embeds = _model.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = (input_ids == _model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_grid_thw = data['video_grid_thw'] + pixel_values_videos = pixel_values_videos.type(_model.visual.get_dtype()) + video_embeds = _model.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_mask = (input_ids == _model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + return {'inputs_embeds': inputs_embeds[0]} def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]: res = super().data_collator(batch, padding_to) From 799a3593ead77e8c7bf5c638918d4212f838c3eb Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 14 Oct 2024 14:26:00 +0800 Subject: [PATCH 3/4] update --- swift/llm/utils/template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index c464b50df..f1b3d2d8d 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -1551,6 +1551,7 @@ def _process_image_qwen(image): class _Qwen2VLTemplateMixin: image_token_id = 151655 video_token_id = 151656 + def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, example: Dict[str, Any]) -> List[Context]: assert media_type in {'image', 'video'} @@ -1621,7 +1622,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An added_tokens_len += token_len - 1 data.update(media_inputs) - inputs['input_ids'] =input_ids + inputs['input_ids'] = input_ids inputs['labels'] = labels data['input_ids'] = torch.tensor(input_ids)[None] inputs['_data'] = _data From 7b299193e3845f3493cc95d905f9ae26464789d0 Mon Sep 17 00:00:00 2001 From: "huangjintao.hjt" Date: Mon, 14 Oct 2024 14:59:15 +0800 Subject: [PATCH 4/4] fix --- ...273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- .../Instruction/Command-line-parameters.md | 2 +- swift/llm/export.py | 2 +- swift/llm/utils/template.py | 14 +++++++------- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git 
"a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 892f36a3f..d67257ede 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -347,7 +347,7 @@ RLHF参数继承了sft参数, 除此之外增加了以下参数: - `--bnb_4bit_quant_type`: 默认值为`'nf4'`. 具体的参数介绍可以在`sft命令行参数`中查看. 若`quantization_bit`设置为0, 则该参数失效. - `--bnb_4bit_use_double_quant`: 默认值为`True`. 具体的参数介绍可以在`sft命令行参数`中查看. 若`quantization_bit`设置为0, 则该参数失效. - `--bnb_4bit_quant_storage`: 默认值为`True`. 具体的参数介绍可以在`sft命令行参数`中查看. 若`quantization_bit`设置为0, 则该参数失效. -- `--🔥max_new_tokens`: 生成新token的最大数量, 默认值为`2048`. +- `--🔥max_new_tokens`: 生成新token的最大数量, 默认值为`2048`. 如果使用部署, 请通过在客户端传入`max_tokens`来控制最大生成的tokens数. - `--🔥do_sample`: 参考文档: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). 默认值为`None`, 继承模型的generation_config. - `--temperature`: 默认值为`None`, 继承模型的generation_config. 该参数只有在`do_sample`设置为True时才生效. 该参数会在部署参数中作为默认值使用. - `--top_k`: 默认值为`None`, 继承模型的generation_config. 该参数只有在`do_sample`设置为True时才生效. 该参数会在部署参数中作为默认值使用. diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 4d3f0d465..8e3ff7955 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -348,7 +348,7 @@ RLHF parameters are an extension of the sft parameters, with the addition of the - `--bnb_4bit_quant_type`: Default is `'nf4'`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect. - `--bnb_4bit_use_double_quant`: Default is `True`. See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect. - `--bnb_4bit_quant_storage`: Default value `None`.See `sft command line arguments` for parameter details. If `quantization_bit` is set to 0, this parameter has no effect. -- `--🔥max_new_tokens`: Maximum number of new tokens to generate, default is `2048`. +- `--🔥max_new_tokens`: Maximum number of new tokens to generate, default is `2048`. If using deployment, please control the maximum number of generated tokens by passing `max_tokens` in the client. - `--🔥do_sample`: Reference document: [https://huggingface.co/docs/transformers/main_classes/text_generation](https://huggingface.co/docs/transformers/main_classes/text_generation). Default is `None`, inheriting the model's generation_config. - `--temperature`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters. - `--top_k`: Default is `None`, inheriting the model's generation_config. This parameter only takes effect when `do_sample` is set to True. This parameter will be used as default value in deployment parameters. 
diff --git a/swift/llm/export.py b/swift/llm/export.py index 0f85a7c5e..05a4bb8e3 100644 --- a/swift/llm/export.py +++ b/swift/llm/export.py @@ -141,7 +141,7 @@ def get_block_name_to_quantize(model: nn.Module, model_type: str) -> Optional[st if module_lists: module_list = max(module_lists, key=lambda x: len(x[1])) _patch_model_forward(module_list[1]) - return f'{prefix}.{module_list[0]}' + return f'{prefix}.{module_list[0]}' if prefix else module_list[0] def gptq_model_quantize(model, tokenizer, batch_size): diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index f1b3d2d8d..0e4029b99 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -1625,7 +1625,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs['input_ids'] = input_ids inputs['labels'] = labels data['input_ids'] = torch.tensor(input_ids)[None] - inputs['_data'] = _data + inputs['_data'] = data return inputs, {} def _post_encode(self, model, data: Any) -> Dict[str, Any]: @@ -1651,17 +1651,17 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]: else: if pixel_values is not None: image_grid_thw = data['image_grid_thw'] - pixel_values = pixel_values.type(_model.visual.get_dtype()) - image_embeds = _model.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = (input_ids == _model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + pixel_values = pixel_values.type(model.visual.get_dtype()) + image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw) + image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_grid_thw = data['video_grid_thw'] - pixel_values_videos = pixel_values_videos.type(_model.visual.get_dtype()) - video_embeds = _model.visual(pixel_values_videos, grid_thw=video_grid_thw) - video_mask = (input_ids == _model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + pixel_values_videos = pixel_values_videos.type(model.visual.get_dtype()) + video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw) + video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) return {'inputs_embeds': inputs_embeds[0]}
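
For context on the _post_encode rework in PATCH 2/4: text tokens are embedded with embed_tokens first, then the rows at image/video placeholder positions are overwritten with the visual-tower features via masked_scatter; plain-text batches under DeepSpeed get a dummy 32x32 image whose features are added as image_embeds.mean() * 0. so the visual tower still participates in the backward pass. The snippet below is a self-contained sketch of the masked_scatter merge only, not the ms-swift code itself; the token id 151655 comes from PATCH 1/4, while the embedding table, hidden size, and visual features are made up for illustration.

    import torch

    IMAGE_TOKEN_ID = 151655                       # image_token_id added in PATCH 1/4

    input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 3]])  # (batch, seq)
    embed_tokens = torch.nn.Embedding(152000, 8)  # stand-in for the LM's embed_tokens
    inputs_embeds = embed_tokens(input_ids)       # (1, 5, 8)

    # Pretend the visual tower produced one feature row per image placeholder token.
    image_embeds = torch.randn(2, 8)

    # Expand the placeholder mask over the hidden dimension and scatter the visual
    # features into those rows, mirroring the patched _post_encode.
    image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(
        image_mask, image_embeds.to(inputs_embeds.device, inputs_embeds.dtype))

    print(inputs_embeds.shape)                    # torch.Size([1, 5, 8])
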