diff --git a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py
index 569513df..74027c6c 100644
--- a/internvl_chat/internvl/model/internlm2/modeling_internlm2.py
+++ b/internvl_chat/internvl/model/internlm2/modeling_internlm2.py
@@ -360,6 +360,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cu_seqlens: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if 'padding_mask' in kwargs:
@@ -456,6 +457,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cu_seqlens: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # InternLM2FlashAttention2 attention does not support output_attentions
@@ -510,7 +512,7 @@ def forward(
         value_states = value_states.transpose(1, 2)
 
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len
+            query_states, key_states, value_states, attention_mask, q_len, cu_seqlens=cu_seqlens
         )
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.wo(attn_output)
@@ -521,7 +523,7 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, cu_seqlens=None
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -544,7 +546,31 @@ def _flash_attention_forward(
         """
         # Contains at least one padding token in the sequence
         causal = self.is_causal and query_length != 1
-        if attention_mask is not None:
+        if cu_seqlens is not None:
+            cu_seqlens = cu_seqlens.to(query_states.device).to(torch.int32).view(-1)
+            cu_seqlens_offset = torch.zeros_like(cu_seqlens)
+            cu_seqlens_offset[:-1] = cu_seqlens[1:]
+            max_seqlen = max(cu_seqlens_offset[:-1] - cu_seqlens[:-1]).item()
+
+            _, _, q_heads, head_dim = query_states.shape
+            _, _, k_heads, head_dim = key_states.shape
+            query_states = query_states.view(-1, q_heads, head_dim)
+            key_states = key_states.view(-1, k_heads, head_dim)
+            value_states = value_states.view(-1, k_heads, head_dim)
+
+            attn_output = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
+        elif attention_mask is not None:
             batch_size = query_states.shape[0]
             query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
                 query_states, key_states, value_states, attention_mask, query_length
@@ -640,6 +666,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -674,6 +701,7 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cu_seqlens=cu_seqlens,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -876,6 +904,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -952,7 +981,7 @@ def forward(
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, output_attentions, None)
+                        return module(*inputs, output_attentions, None, cu_seqlens=cu_seqlens)
 
                     return custom_forward
 
@@ -971,6 +1000,7 @@ def custom_forward(*inputs):
                     past_key_value=past_key_value,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
+                    cu_seqlens=cu_seqlens
                 )
 
         hidden_states = layer_outputs[0]
@@ -1045,6 +1075,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1089,6 +1120,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cu_seqlens=cu_seqlens
         )
 
         hidden_states = outputs[0]
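The cu_seqlens branch above routes packed rows through FlashAttention's variable-length kernel, so tokens belonging to different samples in the same row never attend to each other. A minimal sketch of how such boundaries are built and passed to flash_attn_varlen_func (the tensor sizes here are illustrative, not part of the patch; requires a CUDA device and flash-attn 2.x):

import torch
from flash_attn import flash_attn_varlen_func

# Three samples of 5, 3 and 4 tokens packed into one row of 12 tokens.
seq_lens = torch.tensor([5, 3, 4], dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seq_lens.cumsum(0).to(torch.int32)]).cuda()   # [0, 5, 8, 12]
max_seqlen = int(seq_lens.max())

total_tokens, num_heads, head_dim = int(seq_lens.sum()), 8, 64
q = torch.randn(total_tokens, num_heads, head_dim, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Causal attention restricted to each [cu_seqlens[i], cu_seqlens[i+1]) segment,
# which is what the cu_seqlens branch of _flash_attention_forward relies on.
out = flash_attn_varlen_func(q, k, v,
                             cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                             max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                             causal=True)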
diff --git a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
index 98009fcc..d459c40f 100644
--- a/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
+++ b/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py
@@ -145,6 +145,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cu_seqlens: Optional[torch.LongTensor] = None
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -185,6 +186,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cu_seqlens=cu_seqlens
         )
 
         logits = outputs.logits
diff --git a/internvl_chat/internvl/train/fast_dataset.py b/internvl_chat/internvl/train/fast_dataset.py
new file mode 100644
index 00000000..471d0110
--- /dev/null
+++ b/internvl_chat/internvl/train/fast_dataset.py
@@ -0,0 +1,337 @@
+import json
+
+import numpy as np
+from multiprocessing.pool import ThreadPool as Pool
+
+import os
+import torch
+from torch.utils.data import Dataset
+
+
+def get_token_sum(g):
+    sum = 0
+    for i in g:
+        sum += i[2]
+    return sum
+
+
+def get_vit_num(g):
+    vit_num = 0
+    for _ in g:
+        vit_num += _[1]
+    return vit_num
+
+
+DEFAULT_SEED = 1024
+class BalancedDataset(Dataset):
+    def __init__(self,
+                 dataset=None,
+                 tokenizer=None,
+                 vit_packed_length=15,
+                 llm_packed_length=4096,
+                 llm_thresh={},
+                 worker=64,
+                 iter_time=100):
+        assert dataset is not None
+        self.dataset = dataset
+        self.tokenizer = tokenizer
+        self.vit_packed_length = vit_packed_length
+        self.llm_packed_length = llm_packed_length
+        self.llm_thresh = llm_thresh
+
+        self.vit_lengths, self.llm_lengths = [], []
+        self.worker = worker
+        self.pad_token_id = len(self.tokenizer) - 1
+        self.iter_time = iter_time
+
+        print("Begin preprocessing dataset", flush=True)
+        self.preprocess()
+        print("Preprocessing dataset succeeded", flush=True)
+        self.seed = DEFAULT_SEED
+        self.pack_groups = self.get_packed_groups()
+
+    def preprocess(self):
+        dict_num_tokens = {}
+        num_datasets = len(self.dataset.datasets)
+        for dataset_idx in range(num_datasets):
+            sub_dataset = self.dataset.datasets[dataset_idx]
+            if "token_lengths" in sub_dataset.meta:
+                print(f"Load from cache for dataset {dataset_idx}", flush=True)
+                assert os.path.exists(sub_dataset.meta["token_lengths"]), f"Dataset {dataset_idx} token_lengths file does not exist."
+                with open(sub_dataset.meta["token_lengths"], "r") as f:
+                    token_lengths = json.load(f)
+                dict_num_tokens[dataset_idx] = {
+                    "lengths": len(sub_dataset),
+                    "token_lengths": token_lengths  # sub_dataset.meta["token_lengths"]
+                }
+            else:
+                print(f"Generate length json for dataset {dataset_idx}", flush=True)
+                token_lengths = []
+                origin_indexs = list(range(len(sub_dataset)))
+                token_lengths_dict = dict()
+
+                def decode_text(idx):
+                    meta = sub_dataset.__getitem__(idx)
+                    token_lengths_dict[idx] = {
+                        "vit_num": meta['pixel_values'].shape[0],
+                        "token_num": len(meta['input_ids']),
+                        "image_flags": meta["image_flags"].sum().item()
+                    }
+
+                with Pool(self.worker) as p:
+                    _ = p.map(decode_text, origin_indexs[:])
+                for idx in range(len(sub_dataset)):
+                    token_lengths.append(
+                        token_lengths_dict[idx]
+                    )
+                dict_num_tokens[dataset_idx] = {
+                    "lengths": len(sub_dataset),
+                    "token_lengths": token_lengths
+                }
+            print(f"Finish length json for dataset {dataset_idx}", flush=True)
+        self.dict_num_tokens = dict_num_tokens
+    def _random_groups(self, token_lengths, seed=None):
+        """
+        token_lengths: [(idx, vit_img_num, llm_token_len)]
+        """
+        rng = np.random.RandomState(seed)
+        index = list(range(len(token_lengths)))
+        rng.shuffle(index)
+
+        pack_groups = []
+        vit_token_length_sum, llm_token_length_sum = 0, 0
+        each_group = []
+        for idx, sample_id in enumerate(index):
+            vit_sample_length, llm_sample_length = token_lengths[sample_id][1], token_lengths[sample_id][2]
+            if vit_sample_length > self.vit_packed_length or llm_sample_length > self.llm_packed_length:
+                continue
+            vit_token_length_sum += vit_sample_length
+            llm_token_length_sum += llm_sample_length
+            if vit_token_length_sum > self.vit_packed_length or llm_token_length_sum > self.llm_packed_length:
+                pack_groups.append(each_group)
+                vit_token_length_sum = vit_sample_length
+                llm_token_length_sum = llm_sample_length
+                each_group = [token_lengths[sample_id]]
+            else:
+                each_group.append(token_lengths[sample_id])
+        # Flush the last partially filled group after the loop, so it is kept
+        # even when the final shuffled sample was skipped by the size filter.
+        if len(each_group) > 0:
+            pack_groups.append(each_group)
+        return pack_groups
+
+    def process_random_groups_input(self, groups, accu_length=0):
+        new_groups = []
+        for idx, item in enumerate(groups):
+            if item["vit_num"] == -1:
+                print(f"item {idx} was filtered.", flush=True)
+                continue
+            new_groups.append((idx + accu_length, item['vit_num'], item['token_num']))
+        return new_groups
+
+    def iter_random_groups(self, groups, llm_thresh=None, seed=None, iter_time=300):
+        if llm_thresh is None:
+            llm_thresh = self.llm_packed_length
+        if seed is None:
+            seed = self.seed
+        groups = self._random_groups(groups, seed=seed)
+        if iter_time == 1:
+            return groups
+        output = []
+        for i in range(iter_time - 1):
+            print(f"iter_random_groups {i} / {iter_time - 1}", flush=True)
+            need_process_groups = []
+            for g in groups:
+                vit_num = get_vit_num(g)
+                llm_num = get_token_sum(g)
+                if vit_num == self.vit_packed_length or llm_num >= llm_thresh:
+                    output.append(g)
+                else:
+                    need_process_groups.extend(g)
+            if len(need_process_groups) > 0:
+                groups = self._random_groups(need_process_groups, seed + i)
+            else:
+                break
+        if len(need_process_groups) > 0:
+            output.extend(self._random_groups(need_process_groups, seed + i))
+        return output
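_random_groups shuffles the (idx, vit_num, token_num) tuples and greedily fills a group until either the ViT budget (number of image tiles) or the LLM budget (number of text tokens) would be exceeded; iter_random_groups then reshuffles the under-filled groups for a few more rounds. A toy illustration of the greedy step under assumed budgets (shuffling omitted, not the class method itself):

# Toy illustration: tuples are (sample_idx, vit_num, token_num),
# assumed budgets are vit_packed_length=6 and llm_packed_length=10.
samples = [(0, 2, 4), (1, 3, 5), (2, 1, 3), (3, 4, 6)]

groups, cur, vit_sum, llm_sum = [], [], 0, 0
for s in samples:
    if vit_sum + s[1] > 6 or llm_sum + s[2] > 10:  # would overflow a budget
        groups.append(cur)
        cur, vit_sum, llm_sum = [], 0, 0
    cur.append(s)
    vit_sum += s[1]
    llm_sum += s[2]
if cur:
    groups.append(cur)

print(groups)  # [[(0, 2, 4), (1, 3, 5)], [(2, 1, 3), (3, 4, 6)]]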
+    def collect_packed_info(self, packed_groups):
+        info_dict = {}
+        info_dict['vit_num_info'] = {}
+        vit_num_min = 10000000
+        vit_num_max = 0
+        llm_num_min = 10000000
+        llm_num_max = 0
+        vit_ave_num = 0
+        llm_ave_num = 0
+        sample_num = 0
+        for group in packed_groups:
+            vit_num = get_vit_num(group)
+            llm_num = get_token_sum(group)
+            if vit_num not in info_dict['vit_num_info']:
+                info_dict['vit_num_info'][vit_num] = 0
+            info_dict['vit_num_info'][vit_num] += 1
+            vit_num_min = min(vit_num_min, vit_num)
+            vit_num_max = max(vit_num_max, vit_num)
+            llm_num_min = min(llm_num_min, llm_num)
+            llm_num_max = max(llm_num_max, llm_num)
+            vit_ave_num += vit_num
+            llm_ave_num += llm_num
+            sample_num += len(group)
+        info_dict['vit_num_min'] = vit_num_min
+        info_dict['vit_num_max'] = vit_num_max
+        info_dict['vit_ave_num'] = vit_ave_num / float(len(packed_groups))
+        info_dict['llm_ave_num'] = llm_ave_num / float(len(packed_groups))
+        info_dict['sample_num'] = sample_num
+        info_dict['packed_group_num'] = len(packed_groups)
+        return info_dict
+
+    def find_best_groups(self, input_groups, step=4, step_num=20):
+        best_group_num = 10000000000000
+        best_groups = []
+        best_info_dict = {}
+        best_llm_thresh = 0
+        llm_thresh = self.llm_packed_length
+        for step_id in range(step_num):
+            print(f"find_best_groups {step_id} / {step_num}", flush=True)
+            groups = self.iter_random_groups(input_groups, llm_thresh, seed=self.seed, iter_time=self.iter_time)
+            cur_info_dict = self.collect_packed_info(groups)
+            if cur_info_dict['packed_group_num'] < best_group_num:
+                best_group_num = cur_info_dict['packed_group_num']
+                best_groups = groups
+                best_info_dict = cur_info_dict
+                best_llm_thresh = llm_thresh
+            llm_thresh -= step
+        print(f"llm thresh {best_llm_thresh} best info dict", best_info_dict, flush=True)
+        return best_groups
+
+    def get_packed_groups(self):
+        num_datasets = len(list(self.dict_num_tokens.keys()))
+        accu_length = 0
+        input_groups = []
+        for d_idx in range(num_datasets):
+            dict_item = self.dict_num_tokens[d_idx]
+            token_lengths = dict_item["token_lengths"]
+            groups = self.process_random_groups_input(token_lengths, accu_length)
+            print(f"get_packed_groups {d_idx}.", flush=True)
+            input_groups.extend(groups)
+            accu_length += len(token_lengths)
+        if self.llm_thresh.get('thresh', None) is not None:
+            groups = self.iter_random_groups(input_groups, llm_thresh=self.llm_thresh['thresh'], seed=self.seed, iter_time=self.iter_time)
+        else:
+            groups = self.find_best_groups(input_groups, self.llm_thresh.get('step', 4), self.llm_thresh.get('step_num', 10))
+        print(self.collect_packed_info(groups), flush=True)
+        print("get_packed_groups done!", flush=True)
+        return groups
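get_packed_groups switches on the shape of llm_thresh: a fixed 'thresh' key (the form that internvl_chat_finetune.py passes) triggers a single packing pass, while the step-based form sweeps several thresholds with find_best_groups to minimise the number of packed groups. Both forms are plain dicts, for example:

# Two accepted shapes for the llm_thresh argument of BalancedDataset:
fixed_thresh = {"thresh": 4068}                 # one pass through iter_random_groups
searched_thresh = {"step": 4, "step_num": 10}   # threshold sweep via find_best_groups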
+    def __getitem__(self, item: int):
+        item = item % len(self.pack_groups)
+        # item = random.randint(0, len(self.pack_groups) - 1)
+        while True:
+            try:
+                groups = self.pack_groups[item]
+
+                input_ids, pixel_values = [], []
+                labels, position_ids, image_flags = [], [], []
+                cu_seqlens = [0]
+                for g in groups:
+                    idx, num_patches, llm_length = g
+                    meta = self.dataset.__getitem__(idx)
+                    # print("llm_length: ", llm_length, "input_ids: ", len(meta["input_ids"]))
+                    assert len(meta["input_ids"]) == llm_length
+                    assert meta["pixel_values"].size(0) == num_patches
+                    input_ids.append(meta['input_ids'])
+                    pixel_values.append(meta['pixel_values'])
+                    labels.append(meta['labels'])
+                    cu_seqlens.append(len(meta['input_ids']))
+                    position_ids.extend(list(range(len(meta['input_ids']))))
+                    image_flags.append(meta.get('image_flags', torch.tensor([0], dtype=torch.long)))
+
+                cu_seqlens = np.cumsum(np.array(cu_seqlens)).tolist()
+                input_ids = torch.cat(input_ids)[:self.llm_packed_length]
+                pixel_values = torch.cat(pixel_values)[:self.vit_packed_length]
+                labels = torch.cat(labels)[:self.llm_packed_length]
+                cu_seqlens = torch.clamp(torch.LongTensor(cu_seqlens), max=self.llm_packed_length)
+                position_ids = torch.LongTensor(position_ids)[:self.llm_packed_length]
+                image_flags = torch.cat(image_flags)
+                if len(image_flags) == 0:  # pure llm text
+                    image_flags = torch.tensor([0], dtype=torch.long)
+
+                ret = {
+                    "input_ids": input_ids,
+                    "labels": labels,
+                    "cu_seqlens": cu_seqlens,
+                    "position_ids": position_ids,
+                    "pixel_values": pixel_values,
+                    "image_flags": image_flags
+                }
+                break
+            except Exception as e:
+                print(f"{e}", flush=True)
+                # i = random.randint(0, len(self.raw_data) - 1)
+                item = (item + 100) % len(self.pack_groups)
+        return ret
+
+    def __len__(self):
+        n_packs = len(self.pack_groups)
+        return n_packs
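Each packed item therefore carries its own segment boundaries: cu_seqlens marks where every original sample starts and ends, and position_ids restart from zero at each boundary, which is what the varlen attention path in modeling_internlm2.py expects. For two packed samples of 6 and 4 tokens the relevant fields would look roughly like this (lengths illustrative):

import torch

lens = [6, 4]  # token counts of the two packed samples
cu_seqlens = torch.LongTensor([0] + torch.tensor(lens).cumsum(0).tolist())  # tensor([0, 6, 10])
position_ids = torch.cat([torch.arange(n) for n in lens])                   # 0..5 then 0..3
# input_ids / labels are the concatenation of both samples (length 10);
# pixel_values concatenates the image tiles of both samples along dim 0.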
+
+
+IGNORE_INDEX = -100
+def fast_concat_pad_data_collator(features, pad_id=0):
+
+    first = features[0]
+    batch = {}
+
+    batch_lens = [feat['input_ids'].shape for feat in features]
+    max_item_length = max(batch_lens)[0]
+    for idx in range(len(features)):
+        feat = features[idx]
+        temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
+        temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
+        feat['input_ids'] = temp_input_ids
+        temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
+        temp_labels[:feat['labels'].shape[0]] = feat['labels']
+        feat['labels'] = temp_labels
+        feat['attention_mask'] = feat['input_ids'].ne(pad_id)
+        if "position_ids" in feat:
+            temp_position_ids = torch.LongTensor([0] * max_item_length)
+            temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
+            feat['position_ids'] = temp_position_ids
+        if "cu_seqlens" in feat:
+            feat['cu_seqlens'][-1] = feat['position_ids'].size(0)
+
+    # Special handling for labels.
+    # Ensure that tensor is created with the correct type
+    # (it should be automatically the case, but let's make sure of it.)
+    if 'label' in first and first['label'] is not None:
+        label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
+        dtype = torch.long if isinstance(label, int) else torch.float
+        batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
+    elif 'label_ids' in first and first['label_ids'] is not None:
+        if isinstance(first['label_ids'], torch.Tensor):
+            batch['labels'] = torch.stack([f['label_ids'] for f in features])
+        else:
+            dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
+            batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
+
+    # Handling of all other possible keys.
+    # Again, we will use the first element to figure out which key/values are not None for this model.
+    for k, v in first.items():
+        if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
+                v is not None and not isinstance(v, str):
+            if isinstance(v, torch.Tensor):
+                batch[k] = torch.stack([f[k] for f in features])
+            elif isinstance(v, np.ndarray):
+                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+            else:
+                batch[k] = torch.tensor([f[k] for f in features])
+        if k in ('pixel_values', 'image_flags'):
+            if isinstance(v, torch.Tensor):
+                batch[k] = torch.concat([f[k] for f in features])
+            elif isinstance(v, np.ndarray):
+                batch[k] = torch.concat(np.stack([f[k] for f in features]))
+            else:
+                batch[k] = torch.concat([f[k] for f in features])
+    return batch
\ No newline at end of file
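With the collator above, a packed batch can be fed directly to the patched InternVLChatModel, which threads cu_seqlens down to the varlen attention path. A rough sketch of the wiring, assuming `model` is an InternVLChatModel and `dataset` a BalancedDataset from this patch (device placement and mixed precision omitted):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=1, collate_fn=fast_concat_pad_data_collator)
batch = next(iter(loader))
outputs = model(
    pixel_values=batch['pixel_values'],
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    position_ids=batch['position_ids'],
    image_flags=batch['image_flags'],
    labels=batch['labels'],
    cu_seqlens=batch['cu_seqlens'],  # consumed by InternLM2FlashAttention2
)
loss = outputs.loss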
diff --git a/internvl_chat/internvl/train/internvl_chat_finetune.py b/internvl_chat/internvl/train/internvl_chat_finetune.py
index 2ace5ea1..6b783ab2 100644
--- a/internvl_chat/internvl/train/internvl_chat_finetune.py
+++ b/internvl_chat/internvl/train/internvl_chat_finetune.py
@@ -43,6 +43,7 @@
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils.logging import (enable_default_handler,
                                         enable_explicit_format, set_verbosity)
+from internvl.train.fast_dataset import BalancedDataset, fast_concat_pad_data_collator
 
 # Apply necessary patches for the transformers library
 replace_llama_rmsnorm_with_fused_rmsnorm()
@@ -141,6 +142,10 @@ class ModelArguments:
         metadata={'help': 'Specify the version of pixel shuffle implementation. Default is `v1`.'
                           'Please use `v2` to fix the bug of transposed image.'}
     )
+    tokenizer_path: Optional[str] = field(
+        default=None,
+        metadata={'help': 'Path to tokenizer'}
+    )
 
 
 @dataclass
@@ -200,6 +205,26 @@ class DataTrainingArguments:
         default='imagenet',
         metadata={'help': 'The normalize type for the image. Default is imagenet.'},
     )
+    use_fast_dataset: Optional[bool] = field(
+        default=False,
+        metadata={'help': 'Set to True to use fast dataset.'},
+    )
+    vit_packed_length: Optional[int] = field(
+        default=9,
+        metadata={'help': 'The value for vit packed length. Default is 9.'},
+    )
+    llm_packed_length: Optional[int] = field(
+        default=4096,
+        metadata={'help': 'The value for llm packed length. Default is 4096.'},
+    )
+    llm_thresh: Optional[int] = field(
+        default=4068,
+        metadata={'help': 'The value for llm thresh. Default is 4068.'},
+    )
+    iter_time: Optional[int] = field(
+        default=100,
+        metadata={'help': 'The value for iter_time. Default is 100.'},
+    )
 
 
 class LazySupervisedDataset(Dataset):
@@ -227,6 +252,7 @@ def __init__(
         repeat_time=1,
         normalize_type='imagenet',
         random_seed=0,
+        use_fast_dataset=False,
     ):
         super(LazySupervisedDataset, self).__init__()
         self.ds_name = ds_name
@@ -271,6 +297,8 @@ def __init__(
         self.min_dynamic_patch = min_dynamic_patch
         self.max_dynamic_patch = max_dynamic_patch
         self.normalize_type = normalize_type
+        self.meta = meta
+        self.use_fast_dataset = use_fast_dataset
 
         # If the precomputed length does not exist, roughly estimate the length of
         # each sample to improve the efficiency of group_by_length.
@@ -363,9 +391,10 @@ def multi_modal_get_item(self, data_item):
         preprocess_function = self.get_preprocess_function()
 
         # Preprocess the conversations and generate the return dictionary
+        group_by_length = True if self.use_fast_dataset else self.group_by_length
         ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                   self.tokenizer, [self.num_image_token * num_patches],
-                                  group_by_length=self.group_by_length, ds_name=self.ds_name)
+                                  group_by_length=group_by_length, ds_name=self.ds_name)
 
         # Create the final return dictionary
         ret = dict(
@@ -494,9 +523,10 @@ def pure_text_get_item(self, data_item):
         preprocess_function = self.get_preprocess_function()
 
         # Preprocess the conversations and generate the return dictionary
+        group_by_length = True if self.use_fast_dataset else self.group_by_length
         ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                   self.tokenizer, [self.num_image_token * num_patches], text_only=True,
-                                  group_by_length=self.group_by_length, ds_name=self.ds_name)
+                                  group_by_length=group_by_length, ds_name=self.ds_name)
 
         # Create the final return dictionary
         ret = dict(
@@ -584,6 +614,7 @@ def build_datasets(
             repeat_time=repeat_time,
             normalize_type=normalize_type,
             random_seed=ds_idx,
+            use_fast_dataset=data_args.use_fast_dataset
         )
         logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
         datasets.append(dataset)
@@ -660,7 +691,10 @@ def main():
     set_seed(training_args.seed)
 
     # Load pretrained model, tokenizer, and image processor
-    tokenizer_path = model_args.model_name_or_path or model_args.llm_path
+    if model_args.tokenizer_path is not None:
+        tokenizer_path = model_args.tokenizer_path
+    else:
+        tokenizer_path = model_args.model_name_or_path or model_args.llm_path
     logger.info(f'Loading Tokenizer: {tokenizer_path}')
     tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=False)
@@ -768,6 +802,19 @@ def main():
         min_dynamic_patch=data_args.min_dynamic_patch,
         max_dynamic_patch=data_args.max_dynamic_patch,
         normalize_type=data_args.normalize_type)
+    if data_args.use_fast_dataset:
+        train_dataset = BalancedDataset(
+            dataset=train_dataset,
+            tokenizer=deepcopy(tokenizer),
+            vit_packed_length=data_args.vit_packed_length,  # 20, # 14,8
+            llm_packed_length=data_args.llm_packed_length,  # 8192, # 6144 4096,
+            iter_time=data_args.iter_time,
+            llm_thresh={"thresh": data_args.llm_thresh},  # 8064 6016 4068
+        )
+        concat_pad_data_collator_func = fast_concat_pad_data_collator
+    else:
+        concat_pad_data_collator_func = concat_pad_data_collator
+
     def _freeze_params(module):
         for param in module.parameters():
             param.requires_grad = False
@@ -819,7 +866,7 @@ def _freeze_params(module):
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=None,
         tokenizer=tokenizer,
-        data_collator=concat_pad_data_collator
+        data_collator=concat_pad_data_collator_func
     )
 
     # Training
diff --git a/internvl_chat/tools/data_preprocess_stastics.py b/internvl_chat/tools/data_preprocess_stastics.py
new file mode 100644
index 00000000..9c94c841
--- /dev/null
+++ b/internvl_chat/tools/data_preprocess_stastics.py
@@ -0,0 +1,375 @@
+import json
+from multiprocessing import Manager
+import multiprocessing
+import argparse
+from tqdm import tqdm
+from functools import partial
+import os
+import numpy as np
+from copy import deepcopy
+import torch
+
+from transformers import AutoTokenizer
+from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
+                                      IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
+                                      IMG_START_TOKEN, QUAD_END_TOKEN,
+                                      QUAD_START_TOKEN, REF_END_TOKEN,
+                                      REF_START_TOKEN)
+from internvl.train.dataset import find_closest_aspect_ratio
+from internvl.conversation import get_conv_template
+from torch.utils.data import Dataset
+PROCESSES = 64
+
+
+def get_num_patchs(orig_width, orig_height, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+        i * j <= max_num and i * j >= min_num)
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+    # calculate the target width and height
+    # target_width = image_size * target_aspect_ratio[0]
+    # target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    if use_thumbnail and blocks != 1:
+        blocks += 1
+    return blocks
+
+
+def preprocess_internlm(
+    template_name,
+    sources,
+    tokenizer,
+    num_image_token,
+    text_only=False,
+    group_by_length=False,
+    num_image=1
+):
+    conv = get_conv_template(template_name)
+    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+    # Apply prompt templates
+    conversations = []
+    for i, source in enumerate(sources):
+        if roles[source[0]['from']] != conv.roles[0]:
+            # Skip the first one if it is not from human
+            source = source[1:]
+
+        conv.messages = []
+        for j, sentence in enumerate(source):
+            role = roles[sentence['from']]
+            assert role == conv.roles[j % 2], f'{i}'
+            sentence['value'] = sentence['value'].strip()
+            if sentence['value'][0] == '\n':
+                sentence['value'] = sentence['value'][1:]
+            conv.append_message(role, sentence['value'])
+        conversations.append(conv.get_prompt())
+
+    image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token}{IMG_END_TOKEN}'
+    if not text_only:
+        new_conversations = []
+        for conversation in conversations:
+            conversation = conversation.replace('<image>', image_tokens, num_image)
+            new_conversations.append(conversation)
+        conversations = new_conversations
+
+    # Tokenize conversations
+    input_ids = tokenizer(
+        conversations,
+        return_tensors='pt',
+        padding=False,
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+    ).input_ids
+    return len(input_ids[0])
+
+
+def preprocess_internlm_v2(
+    template_name,
+    sources,
+    tokenizer,
+    num_image_token,
+    text_only=False,
+    group_by_length=False,
+    num_image=1
+):
+    num_image_token_list = [num_image_token]
+    conv = get_conv_template(template_name)
+    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+    # Apply prompt templates
+    conversations = []
+    for i, source in enumerate(sources):
+        if roles[source[0]['from']] != conv.roles[0]:
+            # Skip the first one if it is not from human
+            source = source[1:]
+
+        conv.messages = []
+        for j, sentence in enumerate(source):
+            role = roles[sentence['from']]
+            assert role == conv.roles[j % 2], f'{i}'
+            sentence['value'] = sentence['value'].strip()
+            conv.append_message(role, sentence['value'])
+        conversations.append(conv.get_prompt())
+
+    if not text_only:
+        new_conversations = []
+        for conversation in conversations:
+            for i in range(num_image):
+                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+                conversation = conversation.replace('<image>', image_tokens, 1)
+            new_conversations.append(conversation)
+        conversations = new_conversations
+
+    # Tokenize conversations
+    input_ids = tokenizer(
+        conversations,
+        return_tensors='pt',
+        padding=False,
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+    ).input_ids
+    targets = input_ids.clone()
+    return len(input_ids[0])
+
+
+class DataProcess(Dataset):
+    def __init__(self, template_name, meta, tokenizer, num_image_token, image_size=224, dynamic_image_size=False,
+                 use_thumbnail=False, min_dynamic_patch=1, max_dynamic_patch=6, repeat_time=1, is_train=False,
+                 pad2square=False, group_by_length=False, read_img=False, random_seed=0):
+        super(DataProcess, self).__init__()
+        self.template_name = template_name
+        self.meta = meta
+        self.tokenizer = tokenizer
+        self.num_image_token = num_image_token
+        self.group_by_length = group_by_length
+        self.image_size = image_size
+        self.dynamic_image_size = dynamic_image_size
+        self.use_thumbnail = use_thumbnail
+        self.min_dynamic_patch = min_dynamic_patch
+        self.max_dynamic_patch = max_dynamic_patch
+        with open(meta['annotation'], 'r') as f:
+            self.raw_data = f.readlines()
+            if repeat_time < 1:
+                # choice top len(self.raw_data) * repeat_time samples
+                self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
+        # for v2
+        self.rng = np.random.default_rng(seed=random_seed)
+        self.rng.shuffle(self.raw_data)
+
+    def __len__(self):
+        return len(self.raw_data)
+
+    def multi_modal_get_item(self, data_item):
+        if '<image>' not in data_item['conversations'][0]['value']:
+            data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
+
+        orig_width, orig_height = data_item["width"], data_item["height"]
+        num_patches = get_num_patchs(orig_width, orig_height, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
+                                     image_size=self.image_size, use_thumbnail=self.use_thumbnail)
+
+        # if not self.dynamic_image_size:
+        #     assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
+        # if self.template_name == 'Hermes-2':
+        #     preprocess_function = preprocess_mpt
+        # elif self.template_name == 'internlm2-chat':
+        #     preprocess_function = preprocess_internlm
+        preprocess_function = preprocess_internlm_v2
+        # else:
+        #     preprocess_function = preprocess
+        num_tokens = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
+                                         self.tokenizer, self.num_image_token * num_patches,
+                                         group_by_length=True)
+
+        ret = dict(
+            num_patches=num_patches,
+            num_tokens=num_tokens,
+            image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
+        )
+        return ret
+
+    def pure_text_get_item(self, data_item):
+        num_patches = 1
+        # preprocess_function = preprocess_internlm
+        preprocess_function = preprocess_internlm_v2
+        num_tokens = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
+                                         self.tokenizer, self.num_image_token * num_patches,
+                                         group_by_length=True)
+
+        ret = dict(
+            num_patches=num_patches,
+            num_tokens=num_tokens,
+            image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
+        )
+        return ret
+
+    def __getitem__(self, idx):
+        idx = idx % len(self.raw_data)
+        data_item = json.loads(self.raw_data[idx])
+        if 'image' in data_item and data_item['image'] is not None and len(data_item['image']) != 0:
+            ret = self.multi_modal_get_item(data_item)
+        else:
+            ret = self.pure_text_get_item(data_item)
+        return ret
+
+
+def decode_text(args):
+    cfg_dataset, inds = args
+    dataset = DataProcess(**cfg_dataset)
+    dataset.ds_name = "dummy"
+    token_lengths = []
+    for idx in inds:
+        item = dataset.__getitem__(idx)
+        flag = item['image_flags'].sum().item()
+        if flag == 0:
+            num_vit_patch = item['num_patches']
+            num_token = item['num_tokens']
+            image_flags = 0
+        elif flag == -1:
+            num_vit_patch = -1
+            num_token = -1
+            image_flags = -1
+        else:
+            num_vit_patch = flag
+            num_token = item['num_tokens']
+            image_flags = flag
+
+        token_lengths.append(
+            {
+                "vit_num": num_vit_patch,
+                "token_num": num_token,
+                "image_flags": image_flags
+            }
+        )
+
+    return token_lengths
+
+
+import copy
+def worker(cfg_dataset, ds_name, token_lengths_path, ds_info):
+    dataset = DataProcess(**cfg_dataset)
+    with multiprocessing.Pool(PROCESSES) as pool:
+        token_lengths_all = pool.map(decode_text, [(cfg_dataset, inds) for inds in np.array_split(range(len(dataset)), PROCESSES)])
+    l_token_lengths = []
+    for tmp in token_lengths_all:
+        l_token_lengths.extend(tmp)
+
+    length_save_path = os.path.join(token_lengths_path, f"{ds_name}" + "_token_lengths.json")
+
+    with open(length_save_path, "w") as f:
+        json.dump(l_token_lengths, f, indent=4)
+    if "max_dynamic_patch" in ds_info:
+        info = {
+            "root": ds_info["root"],
+            "annotation": ds_info["annotation"],
+            "data_augment": ds_info["data_augment"],
+            "repeat_time": ds_info["repeat_time"],
+            "length": len(dataset),
+            "token_lengths": length_save_path,
+            "max_dynamic_patch": ds_info["max_dynamic_patch"]
+        }
+    else:
+        info = {
+            "root": ds_info["root"],
+            "annotation": ds_info["annotation"],
+            "data_augment": ds_info["data_augment"],
+            "repeat_time": ds_info["repeat_time"],
+            "length": len(dataset),
+            "token_lengths": length_save_path
+        }
+    return info
+
+
+from tqdm import tqdm
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        default=None,
+        help="data root path",
+    )
+    parser.add_argument(
+        "--json_file",
+        default=None,
+        help="json file for statistics"
+    )
+    parser.add_argument(
+        "--worker",
+        default=64, type=int,
+        help="worker num",
+    )
"--token_lengths_path", + default=None, + help="token_lengths_path", + ) + parser.add_argument( + "--output_path", + default=None, + help="token_lengths_path", + ) + args = parser.parse_args() + + token_lengths_path = args.token_lengths_path + + # setting + model_max_length = 4096 + tokenizer_path = "/path/to/tokenizer" + data_path = args.json_file + + cfg_dataset_base = { + 'template_name': 'internlm2-chat', + 'num_image_token': 256, + 'image_size': 448, + 'dynamic_image_size': True, + 'use_thumbnail': True, + 'min_dynamic_patch': 1, + 'max_dynamic_patch': 4, + 'pad2square': False + } + + # build tokenizer + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=False) + tokenizer.tokenizer_path = tokenizer_path + tokenizer.model_max_length = model_max_length + token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, + QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN, + REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN] + num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True) + + cfg_dataset_base['tokenizer'] = tokenizer + + ds_collections = json.loads(open(data_path).read()) + import time + t_1 = time.time() + meta = {} + idx = 0 + for ds_name in tqdm(ds_collections.keys()): + print(ds_name) + cfg_dataset = copy.deepcopy(cfg_dataset_base) + cfg_dataset['meta'] = ds_collections[ds_name] + cfg_dataset['random_seed'] = idx + ds_info = {} + ds_info["root"] = ds_collections[ds_name]["root"] + ds_info["annotation"] = ds_collections[ds_name]["annotation"] + ds_info["data_augment"] = ds_collections[ds_name].get("data_augment", False) + ds_info["repeat_time"] = ds_collections[ds_name]['repeat_time'] + if 'max_dynamic_patch' in ds_collections[ds_name]: + ds_info['max_dynamic_patch'] = ds_collections[ds_name]['max_dynamic_patch'] + + meta[ds_name] = worker(cfg_dataset, ds_name, token_lengths_path, ds_info) + idx += 1 + + with open(args.output_path, "w") as f: + json.dump(meta.copy(), f, indent=4) + + t_2 = time.time() + print(f"time: {t_2-t_1}") diff --git a/internvl_chat/tools/data_preprocess_stastics.sh b/internvl_chat/tools/data_preprocess_stastics.sh new file mode 100644 index 00000000..997a361c --- /dev/null +++ b/internvl_chat/tools/data_preprocess_stastics.sh @@ -0,0 +1,8 @@ +ROOT=/path/to/InternVL/internvl_chat +export PYTHONPATH=$ROOT:$PYTHONPATH + +export OMP_NUM_THREADS=1 + + +python data_preprocess_stastics.py --json_file $1 --token_lengths_path $2 --output_path $3 2>&1 | tee -a log_statistics.txt +