From f3f955a064d1474f62c164814c24b6c7e42b800c Mon Sep 17 00:00:00 2001
From: mzr1996
Date: Mon, 23 Oct 2023 14:38:44 +0800
Subject: [PATCH] Support LLaVA 1.5

---
 configs/llava/llava-7b-v1.5_caption.py        | 80 ++++++++
 configs/llava/llava-7b-v1.5_vqa.py            | 75 +++++++
 configs/llava/metafile.yml                    | 26 +++
 mmpretrain/models/multimodal/llava/llava.py   | 34 ++--
 mmpretrain/models/multimodal/llava/modules.py | 188 ++++++++----------
 5 files changed, 290 insertions(+), 113 deletions(-)
 create mode 100644 configs/llava/llava-7b-v1.5_caption.py
 create mode 100644 configs/llava/llava-7b-v1.5_vqa.py

diff --git a/configs/llava/llava-7b-v1.5_caption.py b/configs/llava/llava-7b-v1.5_caption.py
new file mode 100644
index 00000000000..e29c84de3c1
--- /dev/null
+++ b/configs/llava/llava-7b-v1.5_caption.py
@@ -0,0 +1,80 @@
+_base_ = '../_base_/default_runtime.py'
+
+meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."  # noqa: E501
+im_patch_token = '<im_patch>'
+patch_size = 14
+image_size = 224
+num_patches = (image_size // patch_size)**2
+caption_prompt = f'''{meta_prompt} User: <image>
+Describe the image in detail. ASSISTANT:'''
+
+# model settings
+model = dict(
+    type='Llava',
+    tokenizer=dict(
+        type='AutoTokenizer',
+        name_or_path='liuhaotian/llava-v1.5-7b'),
+    vision_encoder=dict(
+        type='VisionTransformer',
+        arch='l',
+        patch_size=14,
+        pre_norm=True,
+        norm_cfg=dict(type='LN', eps=1e-5),
+        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
+        final_norm=False,
+        out_type='raw',
+        pretrained=(
+            'https://download.openmmlab.com/mmclassification/v0/clip/'
+            'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
+    ),
+    mm_hidden_size=1024,
+    use_im_patch=False,
+    use_im_start_end=False,
+    mm_proj_depth=2,
+    lang_encoder=dict(
+        type='AutoModelForCausalLM',
+        name_or_path='huggyllama/llama-7b',
+    ),
+    task='caption',
+    prompt_tmpl=caption_prompt,
+    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
+)
+
+# data settings
+data_preprocessor = dict(
+    type='MultiModalDataPreprocessor',
+    mean=[122.770938, 116.7460125, 104.09373615],
+    std=[68.5005327, 66.6321579, 70.32316305],
+    to_rgb=True,
+)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(image_size, image_size),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+test_dataloader = dict(
+    batch_size=8,
+    num_workers=5,
+    dataset=dict(
+        type='COCOCaption',
+        data_root='data/coco',
+        ann_file='annotations/coco_karpathy_val.json',
+        pipeline=test_pipeline,
+    ),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+
+test_evaluator = dict(
+    type='COCOCaption',
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+# schedule settings
+test_cfg = dict()
diff --git a/configs/llava/llava-7b-v1.5_vqa.py b/configs/llava/llava-7b-v1.5_vqa.py
new file mode 100644
index 00000000000..5d49ef5995c
--- /dev/null
+++ b/configs/llava/llava-7b-v1.5_vqa.py
@@ -0,0 +1,75 @@
+_base_ = '../_base_/default_runtime.py'
+
+meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."  # noqa: E501
+image_size = 336
+prompt_tmpl = f'''{meta_prompt} User: <image>
+{{question}} ASSISTANT:'''
+
+# model settings
+model = dict(
+    type='Llava',
+    tokenizer=dict(
+        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
+    vision_encoder=dict(
+        type='VisionTransformer',
+        arch='l',
+        patch_size=14,
+        img_size=image_size,
+        pre_norm=True,
+        norm_cfg=dict(type='LN', eps=1e-5),
+        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
+        final_norm=False,
+        out_type='raw',
+        pretrained=('/mnt/petrelfs/mazerun/mmlab/pretrain/vit-l-p14-336px.pth'),
+    ),
+    mm_hidden_size=1024,
+    use_im_patch=False,
+    use_im_start_end=False,
+    mm_proj_depth=2,
+    lang_encoder=dict(
+        type='AutoModelForCausalLM',
+        name_or_path='huggyllama/llama-7b',
+    ),
+    task='vqa',
+    prompt_tmpl=prompt_tmpl,
+    generation_cfg=dict(max_new_tokens=100),
+)
+
+# data settings
+data_preprocessor = dict(
+    type='MultiModalDataPreprocessor',
+    mean=[122.770938, 116.7460125, 104.09373615],
+    std=[68.5005327, 66.6321579, 70.32316305],
+    to_rgb=True,
+)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(image_size, image_size),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id', 'question']),
+]
+
+test_dataloader = dict(
+    batch_size=8,
+    num_workers=5,
+    dataset=dict(
+        type='COCOCaption',
+        data_root='data/coco',
+        ann_file='annotations/coco_karpathy_val.json',
+        pipeline=test_pipeline,
+    ),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    persistent_workers=True,
+)
+
+test_evaluator = dict(
+    type='COCOCaption',
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+# schedule settings
+test_cfg = dict()
diff --git a/configs/llava/metafile.yml b/configs/llava/metafile.yml
index 2b3cfc4dbae..73b2abc8e76 100644
--- a/configs/llava/metafile.yml
+++ b/configs/llava/metafile.yml
@@ -23,3 +23,29 @@ Models:
           CIDER: null
     Weights: null
     Config: configs/llava/llava-7b-v1_caption.py
+  - Name: llava-7b-v1.5_caption
+    Metadata:
+      FLOPs: null
+      Parameters: 7045816320
+    In Collection: LLaVA
+    Results:
+      - Task: Image Caption
+        Dataset: COCO
+        Metrics:
+          BLEU-4: null
+          CIDER: null
+    Weights: null
+    Config: configs/llava/llava-7b-v1.5_caption.py
+  - Name: llava-7b-v1.5_vqa
+    Metadata:
+      FLOPs: null
+      Parameters: 7045816320
+    In Collection: LLaVA
+    Results:
+      - Task: Visual Question Answering
+        Dataset: COCO
+        Metrics:
+          BLEU-4: null
+          CIDER: null
+    Weights: null
+    Config: configs/llava/llava-7b-v1.5_vqa.py
diff --git a/mmpretrain/models/multimodal/llava/llava.py b/mmpretrain/models/multimodal/llava/llava.py
index 103d81296f0..6cc72e46060 100644
--- a/mmpretrain/models/multimodal/llava/llava.py
+++ b/mmpretrain/models/multimodal/llava/llava.py
@@ -24,8 +24,8 @@ class Llava(BaseModel):
         use_im_start_end (bool): Whether to use the im_start and im_end tokens
         mm_vision_select_layer (int): The index from vision encoder output.
             Defaults to -1.
-        use_mm_proj (bool): Whether to enable multi-modal projection.
-            Defaults to True.
+        mm_proj_depth (int): The number of linear layers for multi-modal
+            projection. Defaults to 1.
         load_lang_pretrained (bool): Whether to load the pretrained model of
             language encoder. Defaults to False.
         generation_cfg (dict): The extra generation config, accept the keyword
@@ -51,9 +51,10 @@ def __init__(self,
                  mm_hidden_size: int,
                  prompt_tmpl: str,
                  task: str = 'caption',
+                 use_im_patch: bool = True,
                  use_im_start_end: bool = False,
                  mm_vision_select_layer: int = -1,
-                 use_mm_proj: bool = True,
+                 mm_proj_depth: int = 1,
                  generation_cfg: dict = dict(),
                  load_lang_pretrained: bool = False,
                  data_preprocessor: Optional[dict] = None,
@@ -75,7 +76,8 @@ def __init__(self,
         # init tokenizer
         self.tokenizer = TOKENIZER.build(tokenizer)
         # add Llava special tokens to the tokenizer
-        self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
+        if use_im_patch:
+            self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
         if use_im_start_end:
             self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
                                       special_tokens=True)
@@ -108,14 +110,12 @@ def __init__(self,
             vision_encoder=vision_encoder,
             lang_encoder=lang_encoder,
             mm_hidden_size=mm_hidden_size,
-            use_mm_proj=use_mm_proj,
+            mm_proj_depth=mm_proj_depth,
             use_im_start_end=use_im_start_end,
             im_start_token=self.tokenizer.convert_tokens_to_ids(
                 self.im_start_token),
             im_end_token=self.tokenizer.convert_tokens_to_ids(
                 self.im_end_token),
-            im_patch_token=self.tokenizer.convert_tokens_to_ids(
-                self.im_patch_token),
             mm_vision_select_layer=mm_vision_select_layer)
 
         self.generation_cfg = generation_cfg
@@ -207,16 +207,24 @@ def preprocess_text(self, data_samples: List[DataSample],
         Returns:
             List[DataSample]: Return list of data samples.
         """
-        prompts = []
+        tokens = []
         for sample in data_samples:
-            final_prompt = self.prompt_tmpl.format(**sample.to_dict())
-            prompts.append(final_prompt)
+            prompt = self.prompt_tmpl.format(**sample.to_dict())
+            input_ids = []
+            while '<image>' in prompt:
+                prefix, _, prompt = prompt.partition('<image>')
+                input_ids.extend(
+                    self.tokenizer(prefix, add_special_tokens=False).input_ids)
+                input_ids.append(-200)
+            if prompt:
+                input_ids.extend(
+                    self.tokenizer(prompt, add_special_tokens=False).input_ids)
+            tokens.append(dict(input_ids=input_ids))
 
         self.tokenizer.padding_side = 'left'
-        input_text = self.tokenizer(
-            prompts,
+        input_text = self.tokenizer.pad(
+            tokens,
             padding='longest',
-            truncation=True,
             return_tensors='pt',
             max_length=2000,
         ).to(device)
diff --git a/mmpretrain/models/multimodal/llava/modules.py b/mmpretrain/models/multimodal/llava/modules.py
index afa6eefadcb..ea82b5ca66d 100644
--- a/mmpretrain/models/multimodal/llava/modules.py
+++ b/mmpretrain/models/multimodal/llava/modules.py
@@ -31,10 +31,10 @@ def __init__(self,
                  lang_encoder,
                  mm_hidden_size,
                  use_im_start_end=True,
-                 use_mm_proj=True,
+                 mm_proj_depth=1,
                  im_start_token: Optional[int] = None,
                  im_end_token: Optional[int] = None,
-                 im_patch_token: Optional[int] = None,
+                 im_token_index: int = -200,
                  mm_vision_select_layer: int = -1):
         super().__init__(lang_encoder.config)
         self.vision_tower = vision_encoder
@@ -43,16 +43,25 @@ def __init__(self,
         self.use_im_start_end = use_im_start_end
         self.im_start_token = im_start_token
         self.im_end_token = im_end_token
-        self.im_patch_token = im_patch_token
         self.mm_hidden_size = mm_hidden_size
         self.mm_vision_select_layer = mm_vision_select_layer
+        self.im_token_index = im_token_index
         self.lang_hidden_size = lang_encoder.config.hidden_size
 
-        if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
+        if mm_proj_depth == 1:
+            # Llava V1
             mm_projector = nn.Linear(self.mm_hidden_size,
                                      self.lang_hidden_size)
             self.lang_encoder.model.add_module('mm_projector', mm_projector)
-        elif not use_mm_proj:
+        elif mm_proj_depth > 1:
+            # Llava V1.5
+            modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)]
+            for _ in range(1, mm_proj_depth):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
+            mm_projector = nn.Sequential(*modules)
+            self.lang_encoder.model.add_module('mm_projector', mm_projector)
+        elif mm_proj_depth == 0:
             self.lang_encoder.model.add_module('mm_projector', nn.Identity())
 
         self.post_init()
@@ -80,16 +89,12 @@ def forward(
             return_dict if return_dict is not None else self.config.use_return_dict)
 
-        # decoder outputs consists of
-        # (dec_features, layer_state, dec_hidden, dec_attn)
-        if inputs_embeds is None:
-            inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)
-
-        inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
-                                                  images)
+        (input_ids, attention_mask, past_key_values, inputs_embeds,
+         labels) = self.forward_vision_tower(input_ids, attention_mask,
+                                             past_key_values, labels, images)
 
         return self.lang_encoder(
-            input_ids=None,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
@@ -127,106 +132,86 @@ def prepare_inputs_for_generation(self,
     def forward_vision_tower(
         self,
         input_ids: torch.LongTensor,
-        inputs_embeds: torch.FloatTensor,
-        images: Union[torch.FloatTensor, list, None] = None,
+        attention_mask: torch.LongTensor,
+        past_key_values: torch.FloatTensor,
+        labels: torch.LongTensor,
+        images: Union[torch.FloatTensor, None] = None,
     ):
-        if self.use_im_start_end:
-            assert self.im_start_token is not None
-            assert self.im_end_token is not None
-        if images is not None:
-            assert self.im_patch_token is not None
-
-        if self.vision_tower is None or images is None or (
-                input_ids.shape[1] == 1 and not self.training):
-            return inputs_embeds
+        if self.vision_tower is None or images is None or input_ids.shape[1] == 1:
+            if past_key_values is not None and self.vision_tower is not None and images is not None and input_ids.shape[1] == 1:
+                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
+            return input_ids, attention_mask, past_key_values, None, labels
 
         with torch.no_grad():
-            if isinstance(images, (list, tuple)):
-                # variable length images
-                image_features = []
-                for image in images:
-                    feats = self.vision_tower(image.unsqueeze(0))
-                    image_feature = feats[self.mm_vision_select_layer][:, 1:]
-                    image_features.append(image_feature)
-            else:
-                feats = self.vision_tower(images)
-                image_features = feats[self.mm_vision_select_layer][:, 1:]
-
-        mm_projector = self.lang_encoder.model.mm_projector
-        if isinstance(images, (list, tuple)):
-            image_features = [
-                mm_projector(image_feature)[0]
-                for image_feature in image_features
-            ]
-        else:
-            image_features = mm_projector(image_features)
+            # TODO: support variable number of images (single now)
+            feats = self.vision_tower(images)
+            image_features = feats[-1][:, 1:]
 
-        dummy_image_features = torch.zeros(
-            256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
-        dummy_image_features = mm_projector(dummy_image_features)
+        image_features = self.lang_encoder.model.mm_projector(image_features)
 
         new_input_embeds = []
-        cur_image_idx = 0
-        for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
-            if (cur_input_ids != self.im_patch_token).all():
+        new_labels = [] if labels is not None else None
+        new_attn_mask = [] if attention_mask is not None else None
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            cur_img = image_features[batch_idx]
+
+            if (cur_input_ids != self.im_token_index).all():
                 # multimodal LLM, but the current sample is not multimodal
-                cur_input_embeds = cur_input_embeds + (
-                    0. * dummy_image_features).sum()
-                new_input_embeds.append(cur_input_embeds)
-                cur_image_idx += 1
+                new_input_embeds.append(self.embed_tokens(cur_input_ids))
+                if labels is not None:
+                    new_labels.append(labels[batch_idx])
+                if attention_mask is not None:
+                    new_attn_mask.append(attention_mask[batch_idx])
                 continue
+
+            img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0]
             if self.use_im_start_end:
-                cur_image_features = image_features[cur_image_idx]
-                num_patches = cur_image_features.shape[0]
-                if (cur_input_ids == self.im_start_token).sum() != (
-                        cur_input_ids == self.im_end_token).sum():
-                    raise ValueError('The number of image start tokens and '
-                                     'image end tokens should be the same.')
-                image_start_tokens = torch.where(
-                    cur_input_ids == self.im_start_token)[0]
-                for image_start_token_pos in image_start_tokens:
-                    cur_image_features = image_features[cur_image_idx].to(
-                        device=cur_input_embeds.device)
-                    num_patches = cur_image_features.shape[0]
-                    if cur_input_ids[image_start_token_pos + num_patches +
-                                     1] != self.im_end_token:
-                        raise ValueError('The image end token should follow '
-                                         'the image start token.')
-                    cur_new_input_embeds = torch.cat(
-                        (cur_input_embeds[:image_start_token_pos + 1],
-                         cur_image_features,
-                         cur_input_embeds[image_start_token_pos + num_patches +
-                                          1:]),
-                        dim=0)
-                    cur_image_idx += 1
-                    new_input_embeds.append(cur_new_input_embeds)
+                cur_new_input_embeds = torch.cat(
+                    [
+                        self.embed_tokens(cur_input_ids[:img_idx - 1]),
+                        self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]),
+                        cur_img,
+                        self.embed_tokens(
+                            cur_input_ids[img_idx + 1:img_idx + 2]),
+                        self.embed_tokens(cur_input_ids[img_idx + 2:]),
+                    ],
+                    dim=0,
+                )
             else:
-                cur_image_features = image_features[cur_image_idx]
-                num_patches = cur_image_features.shape[0]
-                if (cur_input_ids == self.im_patch_token).sum() != num_patches:
-                    print(f'Debug: num_patches: {num_patches}')
-                    raise ValueError(
-                        'The number of image patch tokens should '
-                        'be the same as the number of image patches.')
-                masked_indices = torch.where(
-                    cur_input_ids == self.im_patch_token)[0]
-                mask_index_start = masked_indices[0]
-                if (masked_indices != torch.arange(
-                        mask_index_start,
-                        mask_index_start + num_patches,
-                        device=masked_indices.device,
-                        dtype=masked_indices.dtype)).any():
-                    raise ValueError(
-                        'The image patch tokens should be consecutive.')
                 cur_new_input_embeds = torch.cat(
-                    (cur_input_embeds[:mask_index_start], cur_image_features,
-                     cur_input_embeds[mask_index_start + num_patches:]),
-                    dim=0)
-                new_input_embeds.append(cur_new_input_embeds)
-                cur_image_idx += 1
+                    [
+                        self.embed_tokens(cur_input_ids[:img_idx]),
+                        cur_img,
+                        self.embed_tokens(cur_input_ids[img_idx + 1:]),
+                    ],
+                    dim=0,
+                )
+            new_input_embeds.append(cur_new_input_embeds)
+
+            if labels is not None:
+                cur_new_labels = torch.cat([
+                    labels[batch_idx, :img_idx],
+                    labels.new_full((cur_img.size(0), ), -100),
+                    labels[batch_idx, img_idx+1:],
+                ], dim=0)
+                new_labels.append(cur_new_labels)
+
+            if attention_mask is not None:
+                cur_attn_mask = torch.cat([
+                    attention_mask[batch_idx, :img_idx],
+                    attention_mask.new_full((cur_img.size(0), ), True),
+                    attention_mask[batch_idx, img_idx+1:],
+                ], dim=0)
+                new_attn_mask.append(cur_attn_mask)
+
         inputs_embeds = torch.stack(new_input_embeds, dim=0)
+        if labels is not None:
+            labels = torch.stack(new_labels, dim=0)
+        if attention_mask is not None:
+            attention_mask = torch.stack(new_attn_mask, dim=0)
 
-        return inputs_embeds
+        return None, attention_mask, past_key_values, inputs_embeds, labels
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
@@ -236,3 +221,6 @@ def _reorder_cache(past_key_values, beam_idx):
                     past_state.index_select(0, beam_idx)
                     for past_state in layer_past),
             )
         return reordered_past
+
+    def embed_tokens(self, input_ids):
+        return self.lang_encoder.model.embed_tokens(input_ids)
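
The standalone snippet below is a minimal sketch of the prompt handling this patch adds in
Llava.preprocess_text, included only as an illustration: the prompt is cut at each
<image> placeholder, the surrounding text is tokenized piece by piece, and the sentinel
index -200 (the default im_token_index of LlavaLlamaForCausalLM) marks where
forward_vision_tower later splices in the projected image features. The names
toy_tokenize, split_prompt and IMAGE_TOKEN_INDEX are illustrative and not part of the
patch; toy_tokenize stands in for the real HuggingFace tokenizer call
self.tokenizer(..., add_special_tokens=False).input_ids.

    from typing import List

    # Same sentinel value as `im_token_index` in the patched modules.py.
    IMAGE_TOKEN_INDEX = -200

    def toy_tokenize(text: str) -> List[int]:
        # Hypothetical stand-in tokenizer: one fake id per whitespace word.
        return [len(word) for word in text.split()]

    def split_prompt(prompt: str) -> List[int]:
        # Mirrors the control flow of the new `preprocess_text`: tokenize the
        # text around every `<image>` placeholder separately and insert the
        # sentinel index where the image features will be spliced in later.
        input_ids: List[int] = []
        while '<image>' in prompt:
            prefix, _, prompt = prompt.partition('<image>')
            input_ids.extend(toy_tokenize(prefix))
            input_ids.append(IMAGE_TOKEN_INDEX)
        if prompt:
            input_ids.extend(toy_tokenize(prompt))
        return input_ids

    ids = split_prompt('USER: <image>\nDescribe the image in detail. ASSISTANT:')
    assert ids.count(IMAGE_TOKEN_INDEX) == 1  # exactly one image slot
    print(ids)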