From e273472cd1ba1f19aa2ad491823ec1d6932eb4af Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Fri, 28 Jun 2024 10:23:17 +0800 Subject: [PATCH] enable llava & Qwen-VL multimodal model quantization Signed-off-by: Zhang, Weiwei1 --- auto_round/autoround.py | 91 ++-- auto_round/export/export_to_autogptq.py | 29 +- auto_round/special_model_handler.py | 11 + auto_round/utils.py | 34 +- examples/multimodal-modeling/Llava/README.md | 94 ++++ .../multimodal-modeling/Llava/evaluation.py | 309 +++++++++++ examples/multimodal-modeling/Llava/main.py | 371 ++++++++++++++ .../Llava/mm_evaluation/__init__.py | 1 + .../Llava/mm_evaluation/textvqa.py | 201 ++++++++ .../Llava/run_autoround.sh | 18 + examples/multimodal-modeling/Qwen-VL/main.py | 484 ++++++++++++++++++ .../Qwen-VL/mm_evaluation/__init__.py | 4 + .../Qwen-VL/mm_evaluation/evaluate_vqa.py | 421 +++++++++++++++ .../Qwen-VL/mm_evaluation/vqa.py | 206 ++++++++ .../Qwen-VL/mm_evaluation/vqa_eval.py | 330 ++++++++++++ examples/multimodal-modeling/requirements.txt | 18 + examples/multimodal-modeling/run_autoround.sh | 15 + 17 files changed, 2578 insertions(+), 59 deletions(-) create mode 100644 examples/multimodal-modeling/Llava/README.md create mode 100644 examples/multimodal-modeling/Llava/evaluation.py create mode 100644 examples/multimodal-modeling/Llava/main.py create mode 100644 examples/multimodal-modeling/Llava/mm_evaluation/__init__.py create mode 100644 examples/multimodal-modeling/Llava/mm_evaluation/textvqa.py create mode 100644 examples/multimodal-modeling/Llava/run_autoround.sh create mode 100644 examples/multimodal-modeling/Qwen-VL/main.py create mode 100644 examples/multimodal-modeling/Qwen-VL/mm_evaluation/__init__.py create mode 100644 examples/multimodal-modeling/Qwen-VL/mm_evaluation/evaluate_vqa.py create mode 100644 examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa.py create mode 100644 examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa_eval.py create mode 100644 examples/multimodal-modeling/requirements.txt create mode 100644 examples/multimodal-modeling/run_autoround.sh diff --git a/auto_round/autoround.py b/auto_round/autoround.py index ff388911..d96c8249 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -23,7 +23,7 @@ from .calib_dataset import get_dataloader from .quantizer import WrapperMultiblock, wrapper_block, unwrapper_block, WrapperLinear, unwrapper_layer -from .special_model_handler import check_hidden_state_dim, check_share_attention_mask +from .special_model_handler import check_hidden_state_dim, check_share_attention_mask, check_not_share_position_ids from .utils import ( CpuInfo, block_forward, @@ -89,6 +89,7 @@ class AutoRound(object): data_type (str): The data type to be used (default is "int"). scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels have different choices. + multimodal(bool): Enable multimodal model quantization, (default is "False"). Returns: The quantized model. @@ -124,6 +125,7 @@ def __init__( dynamic_max_gap: int = -1, data_type: str = "int", ##only support int for now scale_dtype: str = "fp16", + multimodal:bool = False, **kwargs, ): self.quantized = False @@ -153,6 +155,7 @@ def __init__( logger.info(f"using {self.model.dtype} for quantization tuning") self.dataset = dataset self.iters = iters + self.multimodal = multimodal if self.iters <= 0: logger.warning("iters must be positive, reset it to 200") self.iters = 200 @@ -203,8 +206,8 @@ def quantize(self): The quantized model and weight configurations. """ # logger.info("cache block input") - block_names = get_block_names(self.model) - if len(block_names) == 0: + all_blocks = get_block_names(self.model) + if len(all_blocks) == 0: logger.warning("could not find blocks, exit with original model") return self.model, self.weight_config @@ -213,29 +216,28 @@ def quantize(self): layer_names = self.get_quantized_layer_names_outside_blocks() self.start_time = time.time() - all_inputs = self.try_cache_inter_data_gpucpu([block_names[0]], self.nsamples, layer_names=layer_names) - del self.inputs - inputs = all_inputs[block_names[0]] - - all_inputs.pop(block_names[0]) - self.inputs = None - del self.inputs - if "input_ids" in inputs.keys(): - total_samples = len(inputs["input_ids"]) - self.nsamples = total_samples - if total_samples < self.train_bs: - self.train_bs = total_samples - logger.warning(f"force the train batch size to {total_samples} ") - - self.model = self.model.to("cpu") - torch.cuda.empty_cache() - self.quant_blocks( - self.model, - inputs, - block_names, - nblocks=self.nblocks, - device=self.device, - ) + all_first_block_names = [block[0] for block in all_blocks] + all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) + for block_names in all_blocks: + inputs = all_inputs[block_names[0]] + all_inputs.pop(block_names[0]) + self.inputs = None + del self.inputs + if "input_ids" in inputs.keys(): + total_samples = len(inputs["input_ids"]) + self.n_samples = total_samples + if total_samples < self.train_bs: + self.train_bs = total_samples + logger.warning(f"force the train batch size to {total_samples} ") + self.model = self.model.to("cpu") + torch.cuda.empty_cache() + self.quant_blocks( + self.model, + inputs, + block_names, + nblocks=self.nblocks, + device=self.device, + ) self.quant_layers(layer_names, all_inputs) @@ -338,11 +340,11 @@ def set_layerwise_config(self, weight_config): Returns: None """ - layers_inblocks = get_layer_names_in_block(self.model, self.supported_types) + layers_in_blocks = get_layer_names_in_block(self.model, self.supported_types) for n, m in self.model.named_modules(): if not isinstance(m, tuple(self.supported_types)): continue - if n not in weight_config.keys() and n in layers_inblocks: + if n not in weight_config.keys() and n in layers_in_blocks: weight_config[n] = {} weight_config[n]["data_type"] = self.data_type weight_config[n]["bits"] = self.bits @@ -396,7 +398,13 @@ def get_block_outputs(self, block, input_ids, input_others, bs, device, cache_de end_index = min(self.nsamples, i + bs) indices = torch.arange(i, end_index).to(torch.long) tmp_input_ids, tmp_input_others = sampling_inputs( - input_ids, input_others, indices, self.seqlen, self.share_attention_mask_flag, self.input_dim + input_ids, + input_others, + indices, + self.seqlen, + self.share_attention_mask_flag, + self.not_share_position_ids_flag, + self.input_dim ) tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to( cache_device @@ -449,6 +457,14 @@ def calib(self, nsamples, bs): for key in data.keys(): data_new[key] = data[key].to(self.model.device) input_ids = data_new["input_ids"] + elif isinstance(data, tuple) or isinstance(data, list): + if self.multimodal: + data_new = {"input_ids": data[0].to(self.model.device), \ + "images": data[1].to(self.model.device, dtype=self.model.dtype), "image_sizes": data[2]} + input_ids = data_new["input_ids"] + else: + data_new = data + input_ids = data_new[0] else: data_new = {} for key in data.keys(): @@ -460,13 +476,15 @@ def calib(self, nsamples, bs): try: if isinstance(data_new, torch.Tensor): self.model(data_new) + elif isinstance(data_new, tuple) or isinstance(data_new, list): + self.model(*data_new) else: self.model(**data_new) except NotImplementedError: pass except Exception as error: logger.error(error) - total_cnt += input_ids.shape[0] + total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 if total_cnt >= nsamples: break if total_cnt == 0: @@ -483,7 +501,7 @@ def calib(self, nsamples, bs): @torch.no_grad() def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=[], last_cache_name=None): - """Attempts to cache intermediate data on GPU,if failed, then using CPU. + """Attempts to cache intermediate data on GPU, if failed, then using CPU. Args: block_names (list): List of block names to cache data for. @@ -542,6 +560,7 @@ def cache_inter_data(self, block_names, nsamples, layer_names=[], last_cache_nam self.last_cache_name = last_cache_name if last_cache_name is None and len(block_names) + len(layer_names) == 1: self.last_cache_name = block_names[0] if len(block_names) == 1 else layer_names[0] + # do not set last_cache_name for multimodal models calib_bs = self.train_bs self.hook_handles = [] self._replace_forward() @@ -579,6 +598,7 @@ def forward(m, hidden_states, *positional_args, **kwargs): if self.share_attention_mask_flag is None: self.input_dim = check_hidden_state_dim(self.model, positional_args) self.share_attention_mask_flag = check_share_attention_mask(self.model, hidden_states, **kwargs) + self.not_share_position_ids_flag = check_not_share_position_ids(self.model, **kwargs) if name in self.inputs: self.inputs[name]["input_ids"].extend(list(torch.split(hidden_states.to("cpu"), 1, dim=self.input_dim))) else: @@ -612,6 +632,13 @@ def forward(m, hidden_states, *positional_args, **kwargs): self.inputs[name][key].extend(list(torch.split(alibi.to("cpu"), 1, dim=0))) else: self.inputs[name][key] = list(torch.split(alibi.to("cpu"), 1, dim=0)) + elif "position_ids" in key: + if key not in self.inputs[name].keys(): + self.inputs[name][key] = list(torch.split(kwargs[key].to("cpu"), 1, dim=0)) \ + if self.not_share_position_ids_flag \ + else to_device(kwargs[key], device=torch.device("cpu")) + elif kwargs[key] is not None and self.not_share_position_ids_flag: + self.inputs[name][key].extend(list(torch.split(kwargs[key].to("cpu"), 1, dim=0))) elif key not in self.inputs[name].keys(): self.inputs[name][key] = to_device(kwargs[key], device=torch.device("cpu")) if name == self.last_cache_name: @@ -847,6 +874,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch indices, seqlen=self.seqlen, share_attention_mask_flag=self.share_attention_mask_flag, + not_share_position_ids_flag=self.not_share_position_ids_flag, input_dim=self.input_dim, ) @@ -1414,3 +1442,4 @@ def __init__( **kwargs, ) + diff --git a/auto_round/export/export_to_autogptq.py b/auto_round/export/export_to_autogptq.py index 67f7b676..244d11d5 100644 --- a/auto_round/export/export_to_autogptq.py +++ b/auto_round/export/export_to_autogptq.py @@ -68,22 +68,23 @@ def save_quantized_as_autogptq(output_dir, use_triton=True, inplace=True, if tokenizer is not None: tokenizer.save_pretrained(output_dir) ##check module quantized in block, this may have bug for mixed precision quantization - block_name = get_block_names(model)[0] - first_block = get_module(model, block_name) + all_blocks = get_block_names(model) all_to_quantized = True modules_in_block_to_quantize = [] - for n, m in first_block.named_modules(): - is_supported_type = False - for supported_type in supported_types: - if isinstance(m, supported_type): - is_supported_type = True - break - if not is_supported_type: - continue - if not check_to_quantized(m): - all_to_quantized = False - else: - modules_in_block_to_quantize.append(n) + for block_names in all_blocks: + first_block = get_module(model, block_names[0]) + for n, m in first_block.named_modules(): + is_supported_type = False + for supported_type in supported_types: + if isinstance(m, supported_type): + is_supported_type = True + break + if not is_supported_type: + continue + if not check_to_quantized(m): + all_to_quantized = False + else: + modules_in_block_to_quantize.append(n) modules_in_block_to_quantize = [modules_in_block_to_quantize] if all_to_quantized: modules_in_block_to_quantize = None diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 82c1aaa1..782c7536 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -16,6 +16,7 @@ share_attention_mask_tuple = ("baichuan",) special_states_dim_tuple = ("chatglm",) +not_share_position_ids_tuple = ("llava",) def check_share_attention_mask(model, hidden_states, attention_mask=None, **kwargs): @@ -54,3 +55,13 @@ def check_hidden_state_dim(model, positional_args): is_special = True break return int(is_special and positional_args is not None) + + +def check_not_share_position_ids(model, **kwargs): + is_special = False + for key in not_share_position_ids_tuple: + if hasattr(model, "config") and key in model.config.model_type: + is_special = True + break + return bool(is_special) + diff --git a/auto_round/utils.py b/auto_round/utils.py index 2b8d817b..cc5b44d3 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -132,7 +132,7 @@ def get_scale_shape(weight, group_size): return shape -def to_device(input, device=torch.device("cpu")): +def to_device(input, device=torch.device("cpu"), multimodal=False): """Moves input data to the specified device. Args: @@ -185,13 +185,15 @@ def get_block_names(model): block_names: A list of block names. """ block_names = [] - target_m = None + target_modules = [] for n, m in model.named_modules(): if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: - target_m = (n, m) - break ## only find the first modulelist, may be not robust - for n, m in target_m[1].named_children(): - block_names.append(target_m[0] + "." + n) + target_modules.append((n, m)) + # break ## only find the first modulelist, may be not robust + for i,target_m in enumerate(target_modules): + block_names.append([]) + for n, m in target_m[1].named_children(): + block_names[i].append(target_m[0] + "." + n) return block_names @@ -232,7 +234,8 @@ def collect_minmax_scale(block): @torch.no_grad() -def sampling_inputs(input_ids, input_others, indices, seqlen, share_attention_mask_flag=False, input_dim=0): +def sampling_inputs(input_ids, input_others, indices, seqlen, + share_attention_mask_flag=False, not_share_position_ids_flag=False, input_dim=0): """Samples inputs based on the given indices and sequence length. Args: @@ -250,7 +253,8 @@ def sampling_inputs(input_ids, input_others, indices, seqlen, share_attention_ma current_input_others = {"positional_inputs": input_others["positional_inputs"]} for key in input_others.keys(): - if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key): + if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key) \ + or (not_share_position_ids_flag and "position_ids" in key): current_input_others[key] = None if input_others[key] is not None: current_input_others[key] = [input_others[key][i] for i in indices] @@ -556,12 +560,13 @@ def get_layer_names_in_block(model, supported_types=[torch.nn.Linear, transforme if isinstance(m, tuple(supported_types)): m.tmp_name = n layers_in_block = [] - block_names = get_block_names(model) - for block_name in block_names: - block = get_module(model, block_name) - for n, m in block.named_modules(): - if hasattr(m, "tmp_name"): - layers_in_block.append(m.tmp_name) + all_blocks = get_block_names(model) + for block_names in all_blocks: + for block_name in block_names: + block = get_module(model, block_name) + for n, m in block.named_modules(): + if hasattr(m, "tmp_name"): + layers_in_block.append(m.tmp_name) for n, m in model.named_modules(): if hasattr(m, "tmp_name"): delattr(m, "tmp_name") @@ -665,3 +670,4 @@ def dynamic_import_inference_linear(bits, group_size, backend): else: from auto_round_extension.cuda.qliner_triton import QuantLinear return QuantLinear + diff --git a/examples/multimodal-modeling/Llava/README.md b/examples/multimodal-modeling/Llava/README.md new file mode 100644 index 00000000..236b7681 --- /dev/null +++ b/examples/multimodal-modeling/Llava/README.md @@ -0,0 +1,94 @@ +Step-by-Step +============ + +This document presents step-by-step instructions for auto-round. +# Run Quantization on Multimodal Models + +In this example, we introduce an straight-forward way to execute quantization on some popular multimodal models such as LLaVA. + +## Install +If you are not using Linux, do NOT proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md). + +1. Clone this repository and navigate to LLaVA folder +```shell +git clone https://github.com/haotian-liu/LLaVA.git +cd LLaVA +``` + +2. Install Package +``` +pip install --upgrade pip # enable PEP 660 support +pip install -e . +``` + +## Download the calibration data + +Our calibration process resembles the official visual instruction tuning process. To align the official implementation of [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main?tab=readme-ov-file#visual-instruction-tuning) + +Please download the annotation of the final mixture our instruction tuning data [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json), and download the images from constituting datasets: + +COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip), and unzip the image folder to any directory you desire. + +
+ +## 2. Run Examples +Enter into the examples folder and install lm-eval to run the evaluation +```bash +pip install -r requirements.txt +``` + +- **Default Settings:** +```bash +CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name liuhaotian/llava-v1.5-7b --bits 4 --group_size 128 +``` + +- **Speedup the tuning:** + +disable_low_gpu_mem_usage(more gpu memory) + +reduce the seqlen to 512(potential large accuracy drop) + +or combine them + +- **Enable quantized lm-head:** + +Currently only support in Intel xpu, however, we found the fake tuning could improve the accuracy is some scenarios. --disable_low_gpu_mem_usage is strongly recommended if the whole model could be loaded to the device, otherwise it will be quite slow to cache the inputs of lm-head. Another way is reducing nsamples,e.g. 128, to alleviate the issue. +```bash +CUDA_VISIBLE_DEVICES=0 python3 main.py --model_name liuhaotian/llava-v1.5-7b --bits 4 --group_size 128 --quant_lm_head --disable_low_gpu_mem_usage +``` + +- **Utilizing the AdamW Optimizer:** + +Include the flag `--adam`. Note that AdamW is less effective than sign gradient descent in many scenarios we tested. + +- **Running on Intel Gaudi2** +```bash +bash run_autoround_on_gaudi.sh +``` + + +## 4. Known Issues +* huggingface format model is not support yet, e.g. llava-1.5-7b-hf + + +## 5. Environment + +PyTorch 1.8 or higher version is needed + + +## Reference +If you find SignRound useful for your research, please cite our paper: +```bash +@article{cheng2023optimize, + title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs}, + author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao}, + journal={arXiv preprint arXiv:2309.05516}, + year={2023} +} +``` + + + + + + diff --git a/examples/multimodal-modeling/Llava/evaluation.py b/examples/multimodal-modeling/Llava/evaluation.py new file mode 100644 index 00000000..3cffbcc3 --- /dev/null +++ b/examples/multimodal-modeling/Llava/evaluation.py @@ -0,0 +1,309 @@ +import torch +import os +import json +from tqdm import tqdm +import shortuuid +import math + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path +from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator +from torch.utils.data import Dataset, DataLoader + +from PIL import Image + +from transformers import AutoProcessor, LlavaForConditionalGeneration + +class CustomDataset(Dataset): + def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode): + self.questions = questions + self.image_folder = image_folder + self.tokenizer = tokenizer + self.image_processor = image_processor + self.model_config = model_config + self.conv_mode = conv_mode + + def __getitem__(self, index): + line = self.questions[index] + image_file = line["image"] + qs = line["text"] + if self.model_config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + conv = conv_templates[self.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') + image_tensor = process_images([image], self.image_processor, self.model_config)[0] + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + + return input_ids, image_tensor, image.size + + def __len__(self): + return len(self.questions) + +# base evaluator +# class BaseEvaluator(object): +# def __init__(self, questions = None, images = None): +# # data related +# self.question_file = questions # the question file to be loaded +# self.image_folder = images # images to be loaded +# self.questions = None +# self.images = None +# self.answer_file = None # file to save the output answers +# # model related +# self.model = None +# self.model_name = None +# self.tokenizer = None +# self.image_processor = None +# self.context_len = None + +# def prepare_model(self): +# raise NotImplementedError + +# def prepare_data(self): +# raise NotImplementedError + +# def run_inference(self, model): +# raise NotImplementedError + +# def calcualate_benchmark(self, result, annotation): +# raise NotImplementedError + +class TextVQAEvaluator(object): + def __init__(self, question_file = None, image_folder = None, *args, **kwargs): + # data related + self.question_file = question_file # the question file to be loaded + self.image_folder = image_folder # images to be loaded + self.questions = None + self.images = None + self.answer_file = None # file to save the output answers + # model related + self.model = None + self.model_name = None + self.tokenizer = None + self.image_processor = None + self.external_args = kwargs + + def prepare_model(self, model_name_or_path = None, model_base = None): + model_path = os.path.expanduser(model_name_or_path) + self.model_name = get_model_name_from_path(model_path) + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, self.model_name) + + def prepare_data(self, num_chunks=1, chunk_idx=0, conv_mode = "vicuna_v1"): + # load textvqa dataloader + from llava.eval.model_vqa_loader import get_chunk, split_list, collate_fn + self.questions = [json.loads(q) for q in open(os.path.expanduser(self.question_file), "r")] + self.questions = get_chunk(self.questions, num_chunks, chunk_idx) + + dataset = CustomDataset(self.questions, self.image_folder, self.tokenizer, self.image_processor, self.model.config, conv_mode) + data_loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=collate_fn) + return data_loader + + def run_inference(self, model_name_or_path, answer_file, temperature = 0): + self.prepare_model(model_name_or_path) + data_loader = self.prepare_data() + self.answer_file = answer_file + ans_file = open(self.answer_file, "w") + # run inference + for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, self.questions), total=len(self.questions)): + idx = line["question_id"] + cur_prompt = line["text"] + input_ids = input_ids.to(device='cuda', non_blocking=True) + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + image_sizes=image_sizes, + do_sample=True if temperature > 0 else False, + temperature=temperature, + top_p=None, + num_beams=1, + max_new_tokens=128, + use_cache=True) + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": self.model_name, + "metadata": {}}) + "\n") + ans_file.close() + + def prompt_processor(self, prompt): + if prompt.startswith('OCR tokens: '): + pattern = r"Question: (.*?) Short answer:" + match = re.search(pattern, prompt, re.DOTALL) + question = match.group(1) + elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: + if prompt.startswith('Reference OCR token:'): + question = prompt.split('\n')[1] + else: + question = prompt.split('\n')[0] + elif len(prompt.split('\n')) == 2: + question = prompt.split('\n')[0] + else: + assert False + return question.lower() + + def calcualate_benchmark(self, answer_file, annotation_file): + from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator + # load the result files + experiment_name = os.path.splitext(os.path.basename(answer_file))[0] + print(experiment_name) + annotations = json.load(open(annotation_file))['data'] + annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} + results = [json.loads(line) for line in open(answer_file)] + + pred_list = [] + for result in results: + annotation = annotations[(result['question_id'], self.prompt_processor(result['prompt']))] + pred_list.append({ + "pred_answer": result['text'], + "gt_answers": annotation['answers'], + }) + + evaluator = TextVQAAccuracyEvaluator() + print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) + +class POPEEvaluator(object): + def __init__(self, question_file = None, image_folder = None, *args, **kwargs): + # data related + self.question_file = question_file # the question file to be loaded + self.image_folder = image_folder # images to be loaded + self.questions = None + self.images = None + self.answer_file = None # file to save the output answers + # model related + self.model = None + self.model_name = None + self.tokenizer = None + self.image_processor = None + self.external_args = kwargs + + def prepare_model(self, model_name_or_path = None, model_base = None): + model_path = os.path.expanduser(model_name_or_path) + self.model_name = get_model_name_from_path(model_path) + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, self.model_name) + + def prepare_data(self, num_chunks=1, chunk_idx=0, conv_mode = "vicuna_v1"): + # load textvqa dataloader + from llava.eval.model_vqa_loader import get_chunk, split_list, collate_fn + self.questions = [json.loads(q) for q in open(os.path.expanduser(self.question_file), "r")] + self.questions = get_chunk(self.questions, num_chunks, chunk_idx) + + dataset = CustomDataset(self.questions, self.image_folder, self.tokenizer, self.image_processor, self.model.config, conv_mode) + data_loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=collate_fn) + return data_loader + + def run_inference(self, model_name_or_path, answer_file, temperature = 0): + self.prepare_model(model_name_or_path) + data_loader = self.prepare_data() + self.answer_file = answer_file + ans_file = open(self.answer_file, "w") + # run inference + for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, self.questions), total=len(self.questions)): + idx = line["question_id"] + cur_prompt = line["text"] + input_ids = input_ids.to(device='cuda', non_blocking=True) + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + image_sizes=image_sizes, + do_sample=True if temperature > 0 else False, + temperature=temperature, + top_p=None, + num_beams=1, + max_new_tokens=128, + use_cache=True) + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": self.model_name, + "metadata": {}}) + "\n") + ans_file.close() + + def calculate_accuracy(self, answers, label_file): + label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] + + for answer in answers: + text = answer['text'] + + # Only keep the first sentence + if text.find('.') != -1: + text = text.split('.')[0] + + text = text.replace(',', '') + words = text.split(' ') + if 'No' in words or 'not' in words or 'no' in words: + answer['text'] = 'no' + else: + answer['text'] = 'yes' + + for i in range(len(label_list)): + if label_list[i] == 'no': + label_list[i] = 0 + else: + label_list[i] = 1 + + pred_list = [] + for answer in answers: + if answer['text'] == 'no': + pred_list.append(0) + else: + pred_list.append(1) + + pos = 1 + neg = 0 + yes_ratio = pred_list.count(1) / len(pred_list) + + TP, TN, FP, FN = 0, 0, 0, 0 + for pred, label in zip(pred_list, label_list): + if pred == pos and label == pos: + TP += 1 + elif pred == pos and label == neg: + FP += 1 + elif pred == neg and label == neg: + TN += 1 + elif pred == neg and label == pos: + FN += 1 + + print('TP\tFP\tTN\tFN\t') + print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) + + precision = float(TP) / float(TP + FP) + recall = float(TP) / float(TP + FN) + f1 = 2*precision*recall / (precision + recall) + acc = (TP + TN) / (TP + TN + FP + FN) + print('Accuracy: {}'.format(acc)) + print('Precision: {}'.format(precision)) + print('Recall: {}'.format(recall)) + print('F1 score: {}'.format(f1)) + print('Yes ratio: {}'.format(yes_ratio)) + print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) + + def calcualate_benchmark(self, question_file, answer_file, annotation_dir): + questions = [json.loads(line) for line in open(question_file)] + questions = {question['question_id']: question for question in questions} + answers = [json.loads(q) for q in open(answer_file)] + for file in os.listdir(annotation_dir): + assert file.startswith('coco_pope_') + assert file.endswith('.json') + category = file[10:-5] + cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] + print('Category: {}, # samples: {}'.format(category, len(cur_answers))) + self.calculate_accuracy(cur_answers, os.path.join(annotation_dir, file)) + print("====================================") diff --git a/examples/multimodal-modeling/Llava/main.py b/examples/multimodal-modeling/Llava/main.py new file mode 100644 index 00000000..7ce67ec5 --- /dev/null +++ b/examples/multimodal-modeling/Llava/main.py @@ -0,0 +1,371 @@ +import argparse +import sys + +sys.path.insert(0, '../..') +sys.path.insert(0, '.') +parser = argparse.ArgumentParser() +import torch +import os +import transformers + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.use_deterministic_algorithms(True, warn_only=True) +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel +from transformers import set_seed + +import re + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +import copy +from PIL import Image +import json +import math +import shortuuid +from torch.utils.data import Dataset, DataLoader +# from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +# from llava.conversation import conv_templates, SeparatorStyle +# from llava.utils import disable_torch_init +# from llava.mm_utils import tokenizer_image_token, process_images +# from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator +from llava.mm_utils import get_model_name_from_path +from llava.train.train import preprocess, preprocess_multimodal +from llava.model.builder import load_pretrained_model + +# from transformers import AutoProcessor, LlavaForConditionalGeneration + +class CustomDataset(Dataset): # for llava tuning + # much refer to https://github.com/haotian-liu/LLaVA/blob/main/llava/train/train.py + def __init__(self, list_data_dict, image_folder, tokenizer, image_processor, args): + self.list_data_dict = list_data_dict + self.image_folder = image_folder + self.tokenizer = tokenizer + self.image_processor = image_processor + self.args = args + + def __getitem__(self, index): + sources = self.list_data_dict[index] + + # image + image_file = os.path.basename(sources["image"]) + image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') + image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([sources["conversations"]]), # a list + self.args, + ) + + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[index]) + ) + if isinstance(index, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + data_dict['image'] = image + return data_dict["input_ids"], data_dict["image"], data_dict["image"].size() + + def __len__(self): + return len(self.list_data_dict) + +@torch.no_grad() +def collate_fn(batch): + input_ids, image_tensors, image_sizes = zip(*batch) + input_ids = torch.stack(input_ids, dim=0) + image_tensors = torch.stack(image_tensors, dim=0) + return input_ids, image_tensors, image_sizes + +def create_data_loader(dataset, batch_size=1): + assert batch_size == 1, "batch_size must be 1" + data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=False, collate_fn=collate_fn) + return data_loader + +if __name__ == '__main__': + + parser.add_argument( + "--model_name", default="facebook/opt-125m" + ) + + parser.add_argument("--bits", default=4, type=int, + help="number of bits") + + parser.add_argument("--group_size", default=128, type=int, + help="group size") + + parser.add_argument("--train_bs", default=1, type=int, + help="train batch size") + + parser.add_argument("--eval_bs", default=4, type=int, + help="eval batch size") + + parser.add_argument("--device", default="auto", type=str, + help="The device to be used for tuning. The default is set to auto/None," + "allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.") + + parser.add_argument("--sym", action='store_true', + help=" sym quantization") + + parser.add_argument("--iters", default=200, type=int, + help=" iters") + + parser.add_argument("--lr", default=None, type=float, + help="learning rate, if None, it will be set to 1.0/iters automatically") + + parser.add_argument("--minmax_lr", default=None, type=float, + help="minmax learning rate, if None,it will beset to be the same with lr") + + parser.add_argument("--seed", default=42, type=int, + help="seed") + + parser.add_argument("--eval_fp16_baseline", action='store_true', + help="whether to eval FP16 baseline") + + parser.add_argument("--adam", action='store_true', + help="adam") + + parser.add_argument("--seqlen", default=2048, type=int, + help="sequence length") + + parser.add_argument("--gradient_accumulate_steps", default=1, type=int, help="gradient accumulate steps") + + parser.add_argument("--nblocks", default=1, type=int, help="num of blocks to tune together") + + parser.add_argument("--nsamples", default=512, type=int, + help="number of samples") + + parser.add_argument("--low_gpu_mem_usage", action='store_true', + help="low_gpu_mem_usage is deprecated") + + parser.add_argument("--deployment_device", default='fake', type=str, + help="targeted inference acceleration platform,The options are 'fake', 'cpu', 'gpu' and 'xpu'." + "default to 'fake', indicating that it only performs fake quantization and won't be exported to any device.") + + parser.add_argument("--scale_dtype", default='fp16', + help="which scale data type to use for quantization, 'fp16', 'fp32' or 'bf16'.") + + parser.add_argument("--tasks", + default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1," \ + "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge,wikitext2,ptb-new,c4-new", + help="lm-eval tasks for lm_eval version 0.4") + + parser.add_argument("--output_dir", default="./tmp_autoround", type=str, + help="Where to store the final model.") + + parser.add_argument("--disable_eval", action='store_true', + help="Whether to do lmeval evaluation.") + + parser.add_argument("--disable_amp", action='store_true', + help="disable amp") + + parser.add_argument("--disable_low_gpu_mem_usage", action='store_true', + help="disable low_gpu_mem_usage") + + parser.add_argument("--disable_minmax_tuning", action='store_true', + help="whether disable enable weight minmax tuning") + + parser.add_argument("--disable_trust_remote_code", action='store_true', + help="Whether to disable trust_remote_code") + + parser.add_argument("--disable_quanted_input", action='store_true', + help="whether to disuse the output of quantized block to tune the next block") + + parser.add_argument("--quant_lm_head", action='store_true', + help="quant_lm_head") + + parser.add_argument("--model_dtype", default=None, type=str, + help="force to convert the dtype, some backends supports fp16 dtype better") + + # ========== Calibration Datasets ============= + parser.add_argument("--mm-use-im-start-end", type=bool, default=False) + + parser.add_argument("--is_multimodal", type=bool, default=False, + help="To determine whether the preprocessing should handle multimodal data.") + + parser.add_argument("--image_folder", default="coco", type=str, + help="The dataset for quantization training. It can be a custom one.") + + parser.add_argument("--question_file", default=None, type=str, + help="The dataset for quantization training. It can be a custom one.") + + # parser.add_argument("--dataset", default=None, type=str, + # help="The dataset for quantization training. It can be a custom one.") + + # ================= Evaluation Related ===================== + parser.add_argument("--eval-question-file", type=str, default="tables/question.jsonl") + + parser.add_argument("--eval-image-folder", type=str) + + parser.add_argument('--eval-result-file', type=str) + + parser.add_argument('--eval-annotation-file', type=str) + + args = parser.parse_args() + + set_seed(args.seed) + tasks = args.tasks + + model_name = args.model_name + if model_name[-1] == "/": + model_name = model_name[:-1] + print(model_name, flush=True) + + from auto_round.utils import detect_device + + device_str = detect_device(args.device) + torch_dtype = "auto" + if "hpu" in device_str: + torch_dtype = torch.bfloat16 + torch_device = torch.device(device_str) + model_path = args.model_name + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base=None, model_name=model_name, + torch_dtype=torch_dtype) + questions = json.load(open(args.question_file, "r")) + dataset = CustomDataset(questions, args.image_folder, tokenizer, image_processor, args) + dataloader = create_data_loader(dataset, args.train_bs) + + from auto_round import (AutoRound, + AutoAdamRound) + + model = model.eval() + + if args.model_dtype != None: + if args.model_dtype == "float16" or args.model_dtype == "fp16": + model = model.to(torch.float16) + if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16": + model = model.to(torch.bfloat16) + + seqlen = args.seqlen + if hasattr(tokenizer, "model_max_length"): + if tokenizer.model_max_length < seqlen: + print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length", + flush=True) + seqlen = min(seqlen, tokenizer.model_max_length) + args.seqlen = seqlen + + excel_name = f"{model_name}_{args.bits}_{args.group_size}" + pt_dtype = torch.float16 + if (hasattr(model, 'config') and (model.dtype is torch.bfloat16 or model.config.torch_dtype is torch.bfloat16)): + dtype = 'bfloat16' + pt_dtype = torch.bfloat16 + else: + if str(args.device) != "cpu": + pt_dtype = torch.float16 + dtype = 'float16' + else: + pt_dtype = torch.float32 + dtype = 'float32' + + if args.eval_fp16_baseline: + if args.disable_low_gpu_mem_usage: + model = model.to(torch_device) + from mm_evaluation import TextVQAEvaluator + evaluator = TextVQAEvaluator( + model, + tokenizer, + image_processor, + args.eval_image_folder, + args.eval_question_file, + args.eval_annotation_file, + model_name = model_name + ) + evaluator.run_evaluate(result_file = args.eval_result_file) + evaluator.calculate_accuracy(result_file = args.eval_result_file) + exit() + + round = AutoRound + if args.adam: + round = AutoAdamRound + + weight_config = {} + for n, m in model.named_modules(): + if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D): + if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: + weight_config[n] = {"data_type": "fp"} + print( + f"{n} will not be quantized due to its shape not being divisible by 32, resulting in an exporting issue to autogptq") + lm_head_layer_name = "lm_head" + for n, _ in model.named_modules(): + lm_head_layer_name = n + if args.quant_lm_head: + from transformers import AutoConfig + + config = model.config + if config.tie_word_embeddings and hasattr(model, "_tied_weights_keys"): + tied_keys = model._tied_weights_keys + for item in tied_keys: + if lm_head_layer_name in item: ##TODO extend to encoder-decoder layer, seq classification model + args.quant_lm_head = False + print( + f"warning, disable quant_lm_head as quantizing lm_head with tied weights has not been " + f"supported currently") + break + if args.quant_lm_head: + weight_config[lm_head_layer_name] = {"data_type": "int"} + transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]] + if transformers_version[0] == 4 and transformers_version[1] < 38: + error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization." + raise EnvironmentError(error_message) + + if args.quant_lm_head and not args.disable_low_gpu_mem_usage: + print(f"warning, disable_low_gpu_mem_usage is strongly recommended if the whole model could be loaded to " + f"gpu") + deployment_device = args.deployment_device.split(',') + gpu_format = "auto_gptq" + if 'gpu' in deployment_device: + if lm_head_layer_name in weight_config.keys() and weight_config[lm_head_layer_name]["data_type"] == "int": + gpu_format = "auto_round" + + if "autoround" in deployment_device or "auto-round" in deployment_device or "auto_round" in deployment_device: + gpu_format = "auto_round" + + autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs, + dataset=dataloader, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, + minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, + amp=not args.disable_amp, nsamples=args.nsamples, + low_gpu_mem_usage=not args.disable_low_gpu_mem_usage, + seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps, + scale_dtype=args.scale_dtype, weight_config=weight_config, + enable_minmax_tuning=not args.disable_minmax_tuning, multimodal=True) + model, _ = autoround.quantize() + model_name = args.model_name.rstrip("/") + + model.eval() + if args.device != "cpu": + torch.cuda.empty_cache() + + export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}" + output_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}-qdq" + + inplace = True if len(deployment_device) < 2 else False + if 'gpu' in deployment_device or "auto_round" in gpu_format or "auto-round" in gpu_format: + autoround.save_quantized(f'{export_dir}-gpu', format=gpu_format, use_triton=True, inplace=inplace) + if 'xpu' in deployment_device: + autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace, + compression_dtype=torch.int8, compression_dim=0, use_optimum_format=False, + device="xpu") + if "cpu" in deployment_device: + autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace) + if "fake" in deployment_device: + model = model.to("cpu") + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + if not args.disable_eval and "fake" in deployment_device: ##support autogptq real eval later + model = model.half() + model = model.to(torch_device) + from mm_evaluation import TextVQAEvaluator + evaluator = TextVQAEvaluator( + model, + tokenizer, + image_processor, + args.eval_image_folder, + args.eval_question_file, + args.eval_annotation_file, + model_name = model_name + ) + evaluator.run_evaluate(result_file = args.eval_result_file) + evaluator.calculate_accuracy(result_file = args.eval_result_file) + diff --git a/examples/multimodal-modeling/Llava/mm_evaluation/__init__.py b/examples/multimodal-modeling/Llava/mm_evaluation/__init__.py new file mode 100644 index 00000000..42c010e5 --- /dev/null +++ b/examples/multimodal-modeling/Llava/mm_evaluation/__init__.py @@ -0,0 +1 @@ +from .textvqa import TextVQAEvaluator \ No newline at end of file diff --git a/examples/multimodal-modeling/Llava/mm_evaluation/textvqa.py b/examples/multimodal-modeling/Llava/mm_evaluation/textvqa.py new file mode 100644 index 00000000..2dd38460 --- /dev/null +++ b/examples/multimodal-modeling/Llava/mm_evaluation/textvqa.py @@ -0,0 +1,201 @@ +import sys +import os +import math +from tqdm import tqdm +import shortuuid +import json +import re + +from PIL import Image + +import torch +from torch.utils.data import Dataset, DataLoader +from llava.utils import disable_torch_init +from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator +from llava.mm_utils import tokenizer_image_token, process_images +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + +def collate_fn(batch): + input_ids, image_tensors, image_sizes = zip(*batch) + input_ids = torch.stack(input_ids, dim=0) + image_tensors = torch.stack(image_tensors, dim=0) + return input_ids, image_tensors, image_sizes + +class CustomDatasetTextVQA(Dataset): + def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode): + self.questions = questions + self.image_folder = image_folder + self.tokenizer = tokenizer + self.image_processor = image_processor + self.model_config = model_config + self.conv_mode = conv_mode + + def __getitem__(self, index): + # import pdb;pdb.set_trace() + line = self.questions[index] + image_file = line["image"] + qs = line["text"] + if self.model_config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs + else: + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + conv = conv_templates[self.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') + image_tensor = process_images([image], self.image_processor, self.model_config)[0] + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') + + return input_ids, image_tensor, image.size + + def __len__(self): + return len(self.questions) + +class TextVQAEvaluator(object): + def __init__( + self, + model, + tokenizer, + image_processor, + image_folder, + question_file, + annotation_file, + **kwargs + ): + super(TextVQAEvaluator, self).__init__() + self.model = model + self.tokenizer = tokenizer + self.image_processor = image_processor + self.image_folder = image_folder + self.question_file = question_file + self.annotation_file = annotation_file + # follow parameters can be set as default value. + self.model_name = kwargs.get("model_name", "llava") + self.conv_mode = kwargs.get("conv_mode", "vicuna_v1") + self.num_chunks = kwargs.get("num_chunks", 1) + self.chunk_idx = kwargs.get("chunk_idx", 0) + self.temperature = kwargs.get("temperature", 0) + self.top_p = kwargs.get("top_p", None) + self.num_beams = kwargs.get("num_beams", 1) + self.max_new_tokens = kwargs.get("max_new_tokens", 128) + + if 'plain' in self.model_name and 'finetune' not in self.model_name.lower() and 'mmtag' not in self.conv_mode: + self.conv_mode = self.conv_mode + '_mmtag' + print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {self.conv_mode}.') + + def create_dataloader(self): + questions = [json.loads(q) for q in open(os.path.expanduser(self.question_file), "r")] + questions = get_chunk(questions, self.num_chunks, self.chunk_idx) + dataset = CustomDatasetTextVQA(questions, self.image_folder, self.tokenizer, self.image_processor, self.model.config, self.conv_mode) + data_loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=collate_fn) + return data_loader, questions + + def run_evaluate(self, result_file = None): + disable_torch_init() + dataloader, questions = self.create_dataloader() + result_file = os.path.expanduser(result_file) + os.makedirs(os.path.dirname(result_file), exist_ok=True) + res_file = open(result_file, "w") + for (input_ids, image_tensor, image_sizes), line in tqdm(zip(dataloader, questions), total=len(questions)): + idx = line["question_id"] + cur_prompt = line["text"] + + input_ids = input_ids.to(device='cuda', non_blocking=True) + + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + image_sizes=image_sizes, + do_sample=True if self.temperature > 0 else False, + temperature=self.temperature, + top_p=self.top_p, + num_beams=self.num_beams, + max_new_tokens=self.max_new_tokens, + use_cache=True) + + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + + ans_id = shortuuid.uuid() + res_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": self.model_name, + "metadata": {}}) + "\n") + res_file.close() + + def prompt_processor(self, prompt): + if prompt.startswith('OCR tokens: '): + pattern = r"Question: (.*?) Short answer:" + match = re.search(pattern, prompt, re.DOTALL) + question = match.group(1) + elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: + if prompt.startswith('Reference OCR token:'): + question = prompt.split('\n')[1] + else: + question = prompt.split('\n')[0] + elif len(prompt.split('\n')) == 2: + question = prompt.split('\n')[0] + else: + assert False + + return question.lower() + + def calculate_accuracy(self, result_file = None): + experiment_name = os.path.splitext(os.path.basename(result_file))[0] + print(experiment_name) + annotations = json.load(open(self.annotation_file))['data'] + annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} + results = [json.loads(line) for line in open(result_file)] + + pred_list = [] + for result in results: + annotation = annotations[(result['question_id'], self.prompt_processor(result['prompt']))] + pred_list.append({ + "pred_answer": result['text'], + "gt_answers": annotation['answers'], + }) + + evaluator = TextVQAAccuracyEvaluator() + print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) + + + +# results + + + +# def eval_single(annotation_file, result_file): +# experiment_name = os.path.splitext(os.path.basename(result_file))[0] +# print(experiment_name) +# annotations = json.load(open(annotation_file))['data'] +# annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} +# results = [json.loads(line) for line in open(result_file)] + +# pred_list = [] +# for result in results: +# annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] +# pred_list.append({ +# "pred_answer": result['text'], +# "gt_answers": annotation['answers'], +# }) + +# evaluator = TextVQAAccuracyEvaluator() +# print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) + + diff --git a/examples/multimodal-modeling/Llava/run_autoround.sh b/examples/multimodal-modeling/Llava/run_autoround.sh new file mode 100644 index 00000000..e35c096a --- /dev/null +++ b/examples/multimodal-modeling/Llava/run_autoround.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -x +device=0 + +CUDA_VISIBLE_DEVICES=$device \ +python3 main.py \ +--model_name=Qwen/Qwen-VL \ +--bits 4 \ +--group_size 128 \ +--iters 200 \ +--deployment_device 'autoround' \ +--image_folder /path/to/coco/images/train2017/ \ +--question_file=/path/to/LLaVA-Instruct-150K/llava_v1_5_mix665k.json \ +--eval-question-file=/path/to/textvqa/llava_textvqa_val_v051_ocr.jsonl \ +--eval-image-folder=/path/to/textvqa/train_images \ +--eval-annotation-file=/path/to/textvqa/TextVQA_0.5.1_val.json \ +--eval-result-file "./tmp_autoround" \ +--output_dir "./tmp_autoround" \ No newline at end of file diff --git a/examples/multimodal-modeling/Qwen-VL/main.py b/examples/multimodal-modeling/Qwen-VL/main.py new file mode 100644 index 00000000..f6a125a3 --- /dev/null +++ b/examples/multimodal-modeling/Qwen-VL/main.py @@ -0,0 +1,484 @@ +import argparse +import sys + +sys.path.insert(0, '../..') +sys.path.insert(0, '.') +parser = argparse.ArgumentParser() +import torch +import os +import transformers + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" +torch.use_deterministic_algorithms(True, warn_only=True) +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel +from transformers import set_seed + +import re + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +import copy +from PIL import Image +import json +import math +import shortuuid +from torch.utils.data import Dataset, DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformers.generation import GenerationConfig +import torch +from auto_round.utils import convert_dtype_torch2str +from typing import Dict, Optional, List +OLD_IMAGE_TOKEN = '' +DEFAULT_IM_START_TOKEN = '' +DEFAULT_IM_END_TOKEN = '' +from transformers.trainer_pt_utils import LabelSmoother +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +import inspect + +def DataFormating(raw_data, image_folder): + for source in raw_data: + source_inputs = source['conversations'] + for sentence in source_inputs: + if OLD_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(OLD_IMAGE_TOKEN, '').strip() + sentence['value'] = OLD_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + replace_img = os.path.join(image_folder, os.path.basename(source["image"])) + replace_token = DEFAULT_IM_START_TOKEN + replace_img + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(OLD_IMAGE_TOKEN, replace_token) + return raw_data + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, + max_len: int, + system_message: str = "You are a helpful assistant." +) -> Dict: + roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} + + im_start = tokenizer.im_start_id + im_end = tokenizer.im_end_id + nl_tokens = tokenizer('\n').input_ids + _system = tokenizer('system').input_ids + nl_tokens + _user = tokenizer('user').input_ids + nl_tokens + _assistant = tokenizer('assistant').input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["user"]: + source = source[1:] + + input_id, target = [], [] + system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens + input_id += system + target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens + assert len(input_id) == len(target) + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + _input_id = tokenizer(role).input_ids + nl_tokens + \ + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens + input_id += _input_id + if role == '<|im_start|>user': + _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens + elif role == '<|im_start|>assistant': + _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \ + _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens + else: + raise NotImplementedError + target += _target + assert len(input_id) == len(target) + input_id += [tokenizer.pad_token_id] * (max_len - len(input_id)) + target += [IGNORE_TOKEN_ID] * (max_len - len(target)) + input_ids.append(input_id[:max_len]) + targets.append(target[:max_len]) + if i >= 512: + break + input_ids = torch.tensor(input_ids, dtype=torch.int) + targets = torch.tensor(targets, dtype=torch.int) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + +class SupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int): + super(SupervisedDataset, self).__init__() + + print("Formatting inputs...") + sources = [example["conversations"] for example in raw_data] + data_dict = preprocess(sources, tokenizer, max_len) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.attention_mask = data_dict["attention_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict( + input_ids=self.input_ids[i], + labels=self.labels[i], + attention_mask=self.attention_mask[i], + ) + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + self.max_len = max_len + + print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + +from transformers.trainer_utils import RemoveColumnsCollator +from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +def set_signature_columns_if_needed(model): + # Inspect model forward signature to keep only the arguments it accepts. + model_to_inspect = model + signature = inspect.signature(model_to_inspect.forward) + signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + signature_columns += list(set(["label", "label_ids", 'labels'])) + return signature_columns + +def get_collator_with_removed_columns(model, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + signature_columns = set_signature_columns_if_needed(model) + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + description=description, + model_name=model.__class__.__name__, + ) + return remove_columns_collator + +def get_train_dataloader(train_dataset, model, data_collator, train_batch_size, num_workers) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + data_collator = get_collator_with_removed_columns(model, data_collator, description="training") + + dataloader_params = { + "batch_size": train_batch_size, + "collate_fn": data_collator, + "num_workers": num_workers, + } + + return DataLoader(train_dataset, **dataloader_params) + + +if __name__ == '__main__': + + parser.add_argument( + "--model_name", default="facebook/opt-125m" + ) + + parser.add_argument("--bits", default=4, type=int, + help="number of bits") + + parser.add_argument("--group_size", default=128, type=int, + help="group size") + + parser.add_argument("--train_bs", default=1, type=int, + help="train batch size") + + parser.add_argument("--eval_bs", default=4, type=int, + help="eval batch size") + + parser.add_argument("--device", default="auto", type=str, + help="The device to be used for tuning. The default is set to auto/None," + "allowing for automatic detection. Currently, device settings support CPU, GPU, and HPU.") + + parser.add_argument("--sym", action='store_true', + help=" sym quantization") + + parser.add_argument("--iters", default=200, type=int, + help=" iters") + + parser.add_argument("--lr", default=None, type=float, + help="learning rate, if None, it will be set to 1.0/iters automatically") + + parser.add_argument("--minmax_lr", default=None, type=float, + help="minmax learning rate, if None,it will beset to be the same with lr") + + parser.add_argument("--seed", default=42, type=int, + help="seed") + + parser.add_argument("--eval_fp16_baseline", action='store_true', + help="whether to eval FP16 baseline") + + parser.add_argument("--adam", action='store_true', + help="adam") + + parser.add_argument("--seqlen", default=2048, type=int, + help="sequence length") + + parser.add_argument("--gradient_accumulate_steps", default=1, type=int, help="gradient accumulate steps") + + parser.add_argument("--nblocks", default=1, type=int, help="num of blocks to tune together") + + parser.add_argument("--nsamples", default=512, type=int, + help="number of samples") + + parser.add_argument("--low_gpu_mem_usage", action='store_true', + help="low_gpu_mem_usage is deprecated") + + parser.add_argument("--deployment_device", default='fake', type=str, + help="targeted inference acceleration platform,The options are 'fake', 'cpu', 'gpu' and 'xpu'." + "default to 'fake', indicating that it only performs fake quantization and won't be exported to any device.") + + parser.add_argument("--scale_dtype", default='fp16', + help="which scale data type to use for quantization, 'fp16', 'fp32' or 'bf16'.") + + parser.add_argument("--tasks", + default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1," \ + "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge,wikitext2,ptb-new,c4-new", + help="lm-eval tasks for lm_eval version 0.4") + + parser.add_argument("--output_dir", default="./tmp_autoround", type=str, + help="Where to store the final model.") + + parser.add_argument("--disable_eval", action='store_true', + help="Whether to do lmeval evaluation.") + + parser.add_argument("--disable_amp", action='store_true', + help="disable amp") + + parser.add_argument("--disable_low_gpu_mem_usage", action='store_true', + help="disable low_gpu_mem_usage") + + parser.add_argument("--disable_minmax_tuning", action='store_true', + help="whether disable enable weight minmax tuning") + + parser.add_argument("--disable_trust_remote_code", action='store_true', + help="Whether to disable trust_remote_code") + + parser.add_argument("--disable_quanted_input", action='store_true', + help="whether to disuse the output of quantized block to tune the next block") + + parser.add_argument("--quant_lm_head", action='store_true', + help="quant_lm_head") + + parser.add_argument("--model_dtype", default=None, type=str, + help="force to convert the dtype, some backends supports fp16 dtype better") + + parser.add_argument("--model_max_length", default=2048, type=int, + help="") + + # ========== Calibration Datasets ============= + parser.add_argument("--image_folder", default="coco", type=str, + help="The dataset for quantization training. It can be a custom one.") + + parser.add_argument("--question_file", default=None, type=str, + help="The dataset for quantization training. It can be a custom one.") + + # parser.add_argument("--dataset", default=None, type=str, + # help="The dataset for quantization training. It can be a custom one.") + + # ================= Evaluation Related ===================== + parser.add_argument("--eval-path", type=str, default="") + + parser.add_argument("--eval-dataset", type=str, default="textvqa_val") + + args = parser.parse_args() + + set_seed(args.seed) + tasks = args.tasks + + model_name = args.model_name + if model_name[-1] == "/": + model_name = model_name[:-1] + print(model_name, flush=True) + + from auto_round.utils import detect_device + + device_str = detect_device(args.device) + torch_dtype = "auto" + if "hpu" in device_str: + torch_dtype = torch.bfloat16 + torch_device = torch.device(device_str) + + torch.manual_seed(1234) + model_name = args.model_name + questions = json.load(open(args.question_file, "r")) + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code, + padding_side="right", use_fast=False) + tokenizer.pad_token_id = tokenizer.eod_id + seqlen = args.seqlen + if hasattr(tokenizer, "model_max_length"): + if tokenizer.model_max_length < seqlen: + print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length", + flush=True) + seqlen = min(seqlen, tokenizer.model_max_length) + args.seqlen = seqlen + + config = transformers.AutoConfig.from_pretrained( + model_name, + trust_remote_code=True, + ) + config.use_cache = False + if args.model_dtype != None: + if args.model_dtype == "float16" or args.model_dtype == "fp16": + torch_device = torch.float16 + if args.model_dtype == "bfloat16" or args.model_dtype == "bfp16": + torch_device = torch.bfloat16 + dtype_abd = convert_dtype_torch2str(torch_dtype) + if dtype_abd == "bf16": + model = AutoModelForCausalLM.from_pretrained(args.model_name, config=config, trust_remote_code=not args.disable_trust_remote_code, bf16=True).eval() + elif dtype_abd == "fp16": + model = AutoModelForCausalLM.from_pretrained(args.model_name, config=config, trust_remote_code=not args.disable_trust_remote_code, fp16=True).eval() + else: + model = AutoModelForCausalLM.from_pretrained(args.model_name, config=config, trust_remote_code=not args.disable_trust_remote_code).eval() + raw_data = DataFormating(questions, args.image_folder) + # dataset = SupervisedDataset(raw_data, tokenizer, max_len=tokenizer.model_max_length) + dataset = LazySupervisedDataset(raw_data, tokenizer, max_len=min(args.seqlen, tokenizer.model_max_length)) + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + dataloader = get_train_dataloader(dataset, model, default_collator, train_batch_size=1, num_workers=0) + + from auto_round import (AutoRound, + AutoAdamRound) + + # model = model.eval() + seqlen = args.seqlen + + if args.eval_fp16_baseline: + if args.disable_low_gpu_mem_usage: + model = model.to(torch_device) + from mm_evaluation.evaluate_vqa import textVQA_evaluation + evaluator = textVQA_evaluation( + model, + dataset_name=args.eval_dataset, + dataset_path=args.eval_path, + tokenizer=tokenizer, + batch_size=args.eval_bs + ) + exit() + + round = AutoRound + if args.adam: + round = AutoAdamRound + + weight_config = {} + for n, m in model.named_modules(): + if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D): + if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: + weight_config[n] = {"data_type": "fp"} + print( + f"{n} will not be quantized due to its shape not being divisible by 32, resulting in an exporting issue to autogptq") + lm_head_layer_name = "lm_head" + for n, _ in model.named_modules(): + lm_head_layer_name = n + if args.quant_lm_head: + from transformers import AutoConfig + + config = model.config + if config.tie_word_embeddings and hasattr(model, "_tied_weights_keys"): + tied_keys = model._tied_weights_keys + for item in tied_keys: + if lm_head_layer_name in item: ##TODO extend to encoder-decoder layer, seq classification model + args.quant_lm_head = False + print( + f"warning, disable quant_lm_head as quantizing lm_head with tied weights has not been " + f"supported currently") + break + if args.quant_lm_head: + weight_config[lm_head_layer_name] = {"data_type": "int"} + transformers_version = [int(item) for item in transformers.__version__.split('.')[:2]] + if transformers_version[0] == 4 and transformers_version[1] < 38: + error_message = "Please upgrade transformers>=4.38.0 to support lm-head quantization." + raise EnvironmentError(error_message) + + if args.quant_lm_head and not args.disable_low_gpu_mem_usage: + print(f"warning, disable_low_gpu_mem_usage is strongly recommended if the whole model could be loaded to " + f"gpu") + deployment_device = args.deployment_device.split(',') + gpu_format = "auto_gptq" + if 'gpu' in deployment_device: + if lm_head_layer_name in weight_config.keys() and weight_config[lm_head_layer_name]["data_type"] == "int": + gpu_format = "auto_round" + + if "autoround" in deployment_device or "auto-round" in deployment_device or "auto_round" in deployment_device: + gpu_format = "auto_round" + + autoround = round(model, tokenizer, args.bits, args.group_size, sym=args.sym, batch_size=args.train_bs, + dataset=dataloader, seqlen=seqlen, nblocks=args.nblocks, iters=args.iters, lr=args.lr, + minmax_lr=args.minmax_lr, enable_quanted_input=not args.disable_quanted_input, device=device_str, + amp=not args.disable_amp, nsamples=args.nsamples, + low_gpu_mem_usage=not args.disable_low_gpu_mem_usage, + seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps, + scale_dtype=args.scale_dtype, weight_config=weight_config, + enable_minmax_tuning=not args.disable_minmax_tuning, multimodal=True) + model, _ = autoround.quantize() + model_name = args.model_name.rstrip("/") + + model.eval() + if args.device != "cpu": + torch.cuda.empty_cache() + + export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}" + output_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}-qdq" + inplace = True if len(deployment_device) < 2 else False + if 'gpu' in deployment_device or "auto_round" in gpu_format or "auto-round" in gpu_format: + autoround.save_quantized(f'{export_dir}-gpu', format=gpu_format, use_triton=True, inplace=inplace) + if 'xpu' in deployment_device: + autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace, + compression_dtype=torch.int8, compression_dim=0, use_optimum_format=False, + device="xpu") + if "cpu" in deployment_device: + autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace) + if "fake" in deployment_device: + model = model.to("cpu") + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + if not args.disable_eval and "fake" in deployment_device: ## TODO + model = model.half() + model = model.to(torch_device) + from mm_evaluation.evaluate_vqa import textVQA_evaluation + evaluator = textVQA_evaluation( + model, + dataset_name=args.eval_dataset, + dataset_path=args.eval_path, + tokenizer=tokenizer, + batch_size=args.eval_bs + ) \ No newline at end of file diff --git a/examples/multimodal-modeling/Qwen-VL/mm_evaluation/__init__.py b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/__init__.py new file mode 100644 index 00000000..dbe21ceb --- /dev/null +++ b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/__init__.py @@ -0,0 +1,4 @@ +if __name__ == "__main__": + import sys + sys.path.insert(0, './') + diff --git a/examples/multimodal-modeling/Qwen-VL/mm_evaluation/evaluate_vqa.py b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/evaluate_vqa.py new file mode 100644 index 00000000..e141e059 --- /dev/null +++ b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/evaluate_vqa.py @@ -0,0 +1,421 @@ +import argparse +import itertools +import json +import os +import random +import time +from functools import partial +from typing import Optional + +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from .vqa import VQA +from .vqa_eval import VQAEval + +# This code is much refer to https://github.com/cognitedata/Qwen-VL-finetune/blob/master/eval_mm/evaluate_vqa.py + +ds_collections = { + 'vqav2_val': { + 'train': 'data/vqav2/vqav2_train.jsonl', + 'test': 'data/vqav2/vqav2_val.jsonl', + 'question': 'data/vqav2/v2_OpenEnded_mscoco_val2014_questions.json', + 'annotation': 'data/vqav2/v2_mscoco_val2014_annotations.json', + 'metric': 'vqa_score', + 'max_new_tokens': 10, + }, + 'vqav2_testdev': { + 'train': 'data/vqav2/vqav2_train.jsonl', + 'test': 'data/vqav2/vqav2_testdev.jsonl', + 'metric': None, + 'max_new_tokens': 10, + }, + 'okvqa_val': { + 'train': 'data/okvqa/okvqa_train.jsonl', + 'test': 'data/okvqa/okvqa_val.jsonl', + 'question': 'data/okvqa/OpenEnded_mscoco_val2014_questions.json', + 'annotation': 'data/okvqa/mscoco_val2014_annotations.json', + 'metric': 'vqa_score', + 'max_new_tokens': 10, + }, + 'textvqa_val': { + 'train': 'data/textvqa/textvqa_train.jsonl', + 'test': 'data/textvqa/textvqa_val.jsonl', + 'question': 'data/textvqa/textvqa_val_questions.json', + 'annotation': 'data/textvqa/textvqa_val_annotations.json', + 'metric': 'vqa_score', + 'max_new_tokens': 10, + }, + 'vizwiz_val': { + 'train': 'data/vizwiz/vizwiz_train.jsonl', + 'test': 'data/vizwiz/vizwiz_val.jsonl', + 'question': 'data/vizwiz/vizwiz_val_questions.json', + 'annotation': 'data/vizwiz/vizwiz_val_annotations.json', + 'metric': 'vqa_score', + 'max_new_tokens': 10, + }, + 'vizwiz_test': { + 'train': 'data/vizwiz/vizwiz_train.jsonl', + 'test': 'data/vizwiz/vizwiz_test.jsonl', + 'metric': None, + 'max_new_tokens': 10, + }, + 'docvqa_val': { + 'train': 'data/docvqa/train.jsonl', + 'test': 'data/docvqa/val.jsonl', + 'annotation': 'data/docvqa/val/val_v1.0.json', + 'metric': 'anls', + 'max_new_tokens': 100, + }, + 'docvqa_test': { + 'train': 'data/docvqa/train.jsonl', + 'test': 'data/docvqa/test.jsonl', + 'metric': None, + 'max_new_tokens': 100, + }, + 'chartqa_test_human': { + 'train': 'data/chartqa/train_human.jsonl', + 'test': 'data/chartqa/test_human.jsonl', + 'metric': 'relaxed_accuracy', + 'max_new_tokens': 100, + }, + 'chartqa_test_augmented': { + 'train': 'data/chartqa/train_augmented.jsonl', + 'test': 'data/chartqa/test_augmented.jsonl', + 'metric': 'relaxed_accuracy', + 'max_new_tokens': 100, + }, + 'gqa_testdev': { + 'train': 'data/gqa/train.jsonl', + 'test': 'data/gqa/testdev_balanced.jsonl', + 'metric': 'accuracy', + 'max_new_tokens': 10, + }, + 'ocrvqa_val': { + 'train': 'data/ocrvqa/ocrvqa_train.jsonl', + 'test': 'data/ocrvqa/ocrvqa_val.jsonl', + 'metric': 'accuracy', + 'max_new_tokens': 100, + }, + 'ocrvqa_test': { + 'train': 'data/ocrvqa/ocrvqa_train.jsonl', + 'test': 'data/ocrvqa/ocrvqa_test.jsonl', + 'metric': 'accuracy', + 'max_new_tokens': 100, + }, + 'ai2diagram_test': { + 'train': 'data/ai2diagram/train.jsonl', + 'test': 'data/ai2diagram/test.jsonl', + 'metric': 'accuracy', + 'max_new_tokens': 10, + } +} + +# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81 +def relaxed_correctness(target: str, + prediction: str, + max_relative_change: float = 0.05) -> bool: + """Calculates relaxed correctness. + + The correctness tolerates certain error ratio defined by max_relative_change. + See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: + “Following Methani et al. (2020), we use a relaxed accuracy measure for the + numeric answers to allow a minor inaccuracy that may result from the automatic + data extraction process. We consider an answer to be correct if it is within + 5% of the gold answer. For non-numeric answers, we still need an exact match + to consider an answer to be correct.” + + Args: + target: Target string. + prediction: Predicted string. + max_relative_change: Maximum relative change. + + Returns: + Whether the prediction was correct given the specified tolerance. + """ + + def _to_float(text: str) -> Optional[float]: + try: + if text.endswith('%'): + # Convert percentages to floats. + return float(text.rstrip('%')) / 100.0 + else: + return float(text) + except ValueError: + return None + + prediction_float = _to_float(prediction) + target_float = _to_float(target) + if prediction_float is not None and target_float: + relative_change = abs(prediction_float - + target_float) / abs(target_float) + return relative_change <= max_relative_change + else: + return prediction.lower() == target.lower() + + +def evaluate_relaxed_accuracy(entries): + scores = [] + for elem in entries: + if isinstance(elem['annotation'], str): + elem['annotation'] = [elem['annotation']] + score = max([ + relaxed_correctness(elem['answer'].strip(), ann) + for ann in elem['annotation'] + ]) + scores.append(score) + return sum(scores) / len(scores) + + +def evaluate_exact_match_accuracy(entries): + scores = [] + for elem in entries: + if isinstance(elem['annotation'], str): + elem['annotation'] = [elem['annotation']] + score = max([ + (1.0 if + (elem['answer'].strip().lower() == ann.strip().lower()) else 0.0) + for ann in elem['annotation'] + ]) + scores.append(score) + return sum(scores) / len(scores) + + +def collate_fn(batches, tokenizer): + + questions = [_['question'] for _ in batches] + question_ids = [_['question_id'] for _ in batches] + annotations = [_['annotation'] for _ in batches] + + input_ids = tokenizer(questions, return_tensors='pt', padding='longest') + + return question_ids, input_ids.input_ids, input_ids.attention_mask, annotations + + +class VQADataset(torch.utils.data.Dataset): + + def __init__(self, train, test, prompt, few_shot): + self.test = open(test).readlines() + self.prompt = prompt + + self.few_shot = few_shot + if few_shot > 0: + self.train = open(train).readlines() + + def __len__(self): + return len(self.test) + + def __getitem__(self, idx): + data = json.loads(self.test[idx].strip()) + image, question, question_id, annotation = data['image'], data[ + 'question'], data['question_id'], data.get('answer', None) + + few_shot_prompt = '' + if self.few_shot > 0: + few_shot_samples = random.sample(self.train, self.few_shot) + for sample in few_shot_samples: + sample = json.loads(sample.strip()) + few_shot_prompt += self.prompt.format( + sample['image'], + sample['question']) + f" {sample['answer']}" + + return { + 'question': few_shot_prompt + self.prompt.format(image, question), + 'question_id': question_id, + 'annotation': annotation + } + + +class InferenceSampler(torch.utils.data.sampler.Sampler): + + def __init__(self, size): + self._size = int(size) + assert size > 0 + self._rank = torch.distributed.get_rank() + self._world_size = torch.distributed.get_world_size() + self._local_indices = self._get_local_indices(size, self._world_size, + self._rank) + + @staticmethod + def _get_local_indices(total_size, world_size, rank): + shard_size = total_size // world_size + left = total_size % world_size + shard_sizes = [shard_size + int(r < left) for r in range(world_size)] + + begin = sum(shard_sizes[:rank]) + end = min(sum(shard_sizes[:rank + 1]), total_size) + return range(begin, end) + + def __iter__(self): + yield from self._local_indices + + def __len__(self): + return len(self._local_indices) + + +def textVQA_evaluation(model_name, dataset_name, dataset_path, tokenizer=None, batch_size=1, few_shot=0, seed=0): + torch.distributed.init_process_group( + backend='nccl', + world_size=int(os.getenv('WORLD_SIZE', '1')), + rank=int(os.getenv('RANK', '0')), + ) + + torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) + if isinstance(model_name, str): + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map='cuda', + trust_remote_code=True).eval() + tokenizer = AutoTokenizer.from_pretrained( + model_name, + trust_remote_code=True + ) + else: + assert tokenizer is not None, "Two types of parameter passing are supported:model_path or model with tokenizer." + model = model_name + + tokenizer.padding_side = 'left' + tokenizer.pad_token_id = tokenizer.eod_id + + prompt = '{}{} Answer:' + + random.seed(seed) + dataset = VQADataset( + train=os.path.join(dataset_path,ds_collections[dataset_name]['train']), + test=os.path.join(dataset_path,ds_collections[dataset_name]['test']), + prompt=prompt, + few_shot=few_shot, + ) + + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + # sampler=InferenceSampler(len(dataset)), + batch_size=batch_size, + # num_workers=0, + pin_memory=True, + drop_last=False, + collate_fn=partial(collate_fn, tokenizer=tokenizer), + ) + + outputs = [] + for _, (question_ids, input_ids, attention_mask, + annotations) in tqdm(enumerate(dataloader)): + pred = model.generate( + input_ids=input_ids.cuda(), + attention_mask=attention_mask.cuda(), + do_sample=False, + num_beams=1, + max_new_tokens=ds_collections[dataset_name]['max_new_tokens'], + min_new_tokens=1, + length_penalty=1, + num_return_sequences=1, + output_hidden_states=True, + use_cache=True, + pad_token_id=tokenizer.eod_id, + eos_token_id=tokenizer.eod_id, + ) + answers = [ + tokenizer.decode(_[input_ids.size(1):].cpu(), + skip_special_tokens=True).strip() for _ in pred + ] + + for question_id, answer, annotation in zip(question_ids, answers, + annotations): + if dataset in ['vqav2_val', 'vqav2_testdev', 'okvqa_val', 'textvqa_val', 'vizwiz_val']: + outputs.append({ + 'question_id': question_id, + 'answer': answer, + }) + elif dataset in ['docvqa_val', 'infographicsvqa', 'gqa_testdev', 'ocrvqa_val', 'ocrvqa_test']: + outputs.append({ + 'questionId': question_id, + 'answer': answer, + 'annotation': annotation, + }) + elif dataset in ['ai2diagram_test']: + outputs.append({ + 'image': question_id, + 'answer': answer, + 'annotation': annotation, + }) + elif dataset in ['chartqa_test_human', 'chartqa_test_augmented']: + outputs.append({ + 'answer': answer, + 'annotation': annotation, + }) + elif dataset in ['docvqa_test']: + outputs.append({ + 'questionId': question_id, + 'answer': answer, + }) + elif dataset in ['vizwiz_test']: + outputs.append({ + 'image': question_id, + 'answer': answer, + }) + else: + raise NotImplementedError + + torch.distributed.barrier() + + world_size = torch.distributed.get_world_size() + merged_outputs = [None for _ in range(world_size)] + torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) + + merged_outputs = [json.loads(_) for _ in merged_outputs] + merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] + + if torch.distributed.get_rank() == 0: + print(f"Evaluating {dataset} ...") + time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) + results_file = f'{dataset}_{time_prefix}_fs{few_shot}_s{seed}.json' + json.dump(merged_outputs, open(results_file, 'w'), ensure_ascii=False) + + if ds_collections[dataset_name]['metric'] == 'vqa_score': + vqa = VQA(os.path.join(dataset_path,ds_collections[dataset_name]['annotation']), + os.path.join(dataset_path,ds_collections[dataset_name]['question'])) + results = vqa.loadRes( + resFile=results_file, + quesFile=os.path.join(dataset_path,ds_collections[dataset_name]['question'])) + vqa_scorer = VQAEval(vqa, results, n=2) + vqa_scorer.evaluate() + + print(vqa_scorer.accuracy) + + elif ds_collections[dataset_name]['metric'] == 'anls': + json.dump(merged_outputs, + open(results_file, 'w'), + ensure_ascii=False) + print('python infographicsvqa_eval.py -g ' + + os.path.join(dataset_path,ds_collections[dataset_name]['annotation']) + ' -s ' + + results_file) + os.system('python infographicsvqa_eval.py -g ' + + os.path.join(dataset_path,ds_collections[dataset_name]['annotation']) + ' -s ' + + results_file) + elif ds_collections[dataset_name]['metric'] == 'relaxed_accuracy': + print({ + 'relaxed_accuracy': evaluate_relaxed_accuracy(merged_outputs) + }) + elif ds_collections[dataset_name]['metric'] == 'accuracy': + if 'gqa' in dataset: + for entry in merged_outputs: + response = entry['answer'] + response = response.strip().split('.')[0].split( + ',')[0].split('!')[0].lower() + if 'is ' in response: + response = response.split('is ')[1] + if 'are ' in response: + response = response.split('are ')[1] + if 'a ' in response: + response = response.split('a ')[1] + if 'an ' in response: + response = response.split('an ')[1] + if 'the ' in response: + response = response.split('the ')[1] + if ' of' in response: + response = response.split(' of')[0] + response = response.strip() + entry['answer'] = response + print({'accuracy': evaluate_exact_match_accuracy(merged_outputs)}) + + torch.distributed.barrier() \ No newline at end of file diff --git a/examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa.py b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa.py new file mode 100644 index 00000000..652807d1 --- /dev/null +++ b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa.py @@ -0,0 +1,206 @@ +"""Copyright (c) 2022, salesforce.com, inc. + +All rights reserved. +SPDX-License-Identifier: BSD-3-Clause +For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = 'aagrawal' +__version__ = '0.9' + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import copy +import datetime +import json + + +class VQA: + + def __init__(self, annotation_file=None, question_file=None): + """Constructor of VQA helper class for reading and visualizing + questions and answers. + + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + print('loading VQA annotations and questions into memory...') + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, 'r')) + questions = json.load(open(question_file, 'r')) + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + # create index + print('creating index...') + imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} + qa = {ann['question_id']: [] for ann in self.dataset['annotations']} + qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} + for ann in self.dataset['annotations']: + imgToQA[ann['image_id']] += [ann] + qa[ann['question_id']] = ann + for ques in self.questions['questions']: + qqa[ques['question_id']] = ques + print('index created!') + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """Print information about the VQA annotation file. + + :return: + """ + for key, value in self.datset['info'].items(): + print('%s: %s' % (key, value)) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """Get question ids that satisfy given filter conditions. default skips + that filter. + + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset['annotations'] + else: + if not len(imgIds) == 0: + anns = sum( + [ + self.imgToQA[imgId] + for imgId in imgIds if imgId in self.imgToQA + ], + [], + ) + else: + anns = self.dataset['annotations'] + anns = (anns if len(quesTypes) == 0 else + [ann for ann in anns if ann['question_type'] in quesTypes]) + anns = (anns if len(ansTypes) == 0 else + [ann for ann in anns if ann['answer_type'] in ansTypes]) + ids = [ann['question_id'] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """Get image ids that satisfy given filter conditions. default skips + that filter. + + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset['annotations'] + else: + if not len(quesIds) == 0: + anns = sum([ + self.qa[quesId] for quesId in quesIds if quesId in self.qa + ], []) + else: + anns = self.dataset['annotations'] + anns = (anns if len(quesTypes) == 0 else + [ann for ann in anns if ann['question_type'] in quesTypes]) + anns = (anns if len(ansTypes) == 0 else + [ann for ann in anns if ann['answer_type'] in ansTypes]) + ids = [ann['image_id'] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """Load questions and answers with the specified question ids. + + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """Display the specified annotations. + + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann['question_id'] + print('Question: %s' % (self.qqa[quesId]['question'])) + for ans in ann['answers']: + print('Answer %d: %s' % (ans['answer_id'], ans['answer'])) + + def loadRes(self, resFile, quesFile): + """Load result file and return a result object. + + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset['info'] = copy.deepcopy(self.questions['info']) + res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) + res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) + res.dataset['data_subtype'] = copy.deepcopy( + self.questions['data_subtype']) + res.dataset['license'] = copy.deepcopy(self.questions['license']) + + print('Loading and preparing results... ') + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, 'results is not an array of objects' + annsQuesIds = [ann['question_id'] for ann in anns] + assert set(annsQuesIds) == set( + self.getQuesIds() + ), 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' + for ann in anns: + quesId = ann['question_id'] + if res.dataset['task_type'] == 'Multiple Choice': + assert ( + ann['answer'] in self.qqa[quesId]['multiple_choices'] + ), 'predicted answer is not one of the multiple choices' + qaAnn = self.qa[quesId] + ann['image_id'] = qaAnn['image_id'] + ann['question_type'] = qaAnn['question_type'] + ann['answer_type'] = qaAnn['answer_type'] + print('DONE (t=%0.2fs)' % + ((datetime.datetime.utcnow() - time_t).total_seconds())) + + res.dataset['annotations'] = anns + res.createIndex() + return res \ No newline at end of file diff --git a/examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa_eval.py b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa_eval.py new file mode 100644 index 00000000..a44e90eb --- /dev/null +++ b/examples/multimodal-modeling/Qwen-VL/mm_evaluation/vqa_eval.py @@ -0,0 +1,330 @@ +"""Copyright (c) 2022, salesforce.com, inc. + +All rights reserved. +SPDX-License-Identifier: BSD-3-Clause +For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +# coding=utf-8 + +__author__ = 'aagrawal' + +import re +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys + + +class VQAEval: + + def __init__(self, vqa=None, vqaRes=None, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + if vqa is not None: + self.params = {'question_id': vqa.getQuesIds()} + self.contractions = { + 'aint': "ain't", + 'arent': "aren't", + 'cant': "can't", + 'couldve': "could've", + 'couldnt': "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + 'didnt': "didn't", + 'doesnt': "doesn't", + 'dont': "don't", + 'hadnt': "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + 'hasnt': "hasn't", + 'havent': "haven't", + 'hed': "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + 'hes': "he's", + 'howd': "how'd", + 'howll': "how'll", + 'hows': "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + 'Im': "I'm", + 'Ive': "I've", + 'isnt': "isn't", + 'itd': "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + 'itll': "it'll", + "let's": "let's", + 'maam': "ma'am", + 'mightnt': "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + 'mightve': "might've", + 'mustnt': "mustn't", + 'mustve': "must've", + 'neednt': "needn't", + 'notve': "not've", + 'oclock': "o'clock", + 'oughtnt': "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + 'shant': "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + 'shouldve': "should've", + 'shouldnt': "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": 'somebodyd', + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + 'somebodyll': "somebody'll", + 'somebodys': "somebody's", + 'someoned': "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + 'someonell': "someone'll", + 'someones': "someone's", + 'somethingd': "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + 'somethingll': "something'll", + 'thats': "that's", + 'thered': "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + 'therere': "there're", + 'theres': "there's", + 'theyd': "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + 'theyll': "they'll", + 'theyre': "they're", + 'theyve': "they've", + 'twas': "'twas", + 'wasnt': "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + 'weve': "we've", + 'werent': "weren't", + 'whatll': "what'll", + 'whatre': "what're", + 'whats': "what's", + 'whatve': "what've", + 'whens': "when's", + 'whered': "where'd", + 'wheres': "where's", + 'whereve': "where've", + 'whod': "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + 'wholl': "who'll", + 'whos': "who's", + 'whove': "who've", + 'whyll': "why'll", + 'whyre': "why're", + 'whys': "why's", + 'wont': "won't", + 'wouldve': "would've", + 'wouldnt': "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + 'yall': "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + 'youd': "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + 'youll': "you'll", + 'youre': "you're", + 'youve': "you've", + } + self.manualMap = { + 'none': '0', + 'zero': '0', + 'one': '1', + 'two': '2', + 'three': '3', + 'four': '4', + 'five': '5', + 'six': '6', + 'seven': '7', + 'eight': '8', + 'nine': '9', + 'ten': '10', + } + self.articles = ['a', 'an', 'the'] + + self.periodStrip = re.compile('(?!<=\d)(\.)(?!\d)') + self.commaStrip = re.compile('(\d)(,)(\d)') + self.punct = [ + ';', + r'/', + '[', + ']', + '"', + '{', + '}', + '(', + ')', + '=', + '+', + '\\', + '_', + '-', + '>', + '<', + '@', + '`', + ',', + '?', + '!', + ] + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params['question_id']] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + print('computing accuracy') + step = 0 + for quesId in quesIds: + resAns = res[quesId]['answer'] + resAns = resAns.replace('\n', ' ') + resAns = resAns.replace('\t', ' ') + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]['answers']: + ansDic['answer'] = self.processPunctuation( + ansDic['answer']) + for gtAnsDatum in gts[quesId]['answers']: + otherGTAns = [ + item for item in gts[quesId]['answers'] + if item != gtAnsDatum + ] + matchingAns = [ + item for item in otherGTAns if item['answer'] == resAns + ] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + quesType = gts[quesId]['question_type'] + ansType = gts[quesId]['answer_type'] + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step % 100 == 0: + self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + print('Done computing accuracy') + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + ' ' in inText or ' ' + p + in inText) or (re.search(self.commaStrip, inText) != None): + outText = outText.replace(p, '') + else: + outText = outText.replace(p, ' ') + outText = self.periodStrip.sub('', outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = ' '.join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy['overall'] = round(100 * float(sum(accQA)) / len(accQA), + self.n) + self.accuracy['perQuestionType'] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / + len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + self.accuracy['perAnswerType'] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / + len(accAnsType[ansType]), self.n) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = '' + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = 'error: progress var must be float\r\n' + if progress < 0: + progress = 0 + status = 'Halt...\r\n' + if progress >= 1: + progress = 1 + status = 'Done...\r\n' + block = int(round(barLength * progress)) + text = '\rFinshed Percent: [{0}] {1}% {2}'.format( + '#' * block + '-' * (barLength - block), int(progress * 100), + status) + sys.stdout.write(text) + sys.stdout.flush() \ No newline at end of file diff --git a/examples/multimodal-modeling/requirements.txt b/examples/multimodal-modeling/requirements.txt new file mode 100644 index 00000000..9b0df5e0 --- /dev/null +++ b/examples/multimodal-modeling/requirements.txt @@ -0,0 +1,18 @@ +transformers +torch +git+https://github.com/EleutherAI/lm-evaluation-harness.git@96d185fa6232a5ab685ba7c43e45d1dbb3bb906d +# For the paper results use the old lm_eval (0.3.0) +# git+https://github.com/EleutherAI/lm-evaluation-harness.git@008fc2a23245c40384f2312718433eeb1e0f87a9 +tiktoken +transformers_stream_generator +peft +sentencepiece +einops +accelerate +datasets +protobuf +auto-gptq +openpyxl +wandb +py-cpuinfo + diff --git a/examples/multimodal-modeling/run_autoround.sh b/examples/multimodal-modeling/run_autoround.sh new file mode 100644 index 00000000..3a80b6bd --- /dev/null +++ b/examples/multimodal-modeling/run_autoround.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -x +device=0 + +CUDA_VISIBLE_DEVICES=$device \ +python3 main.py \ +--model_name=liuhaotian/llava-v1.5-7b \ +--bits 4 \ +--group_size 128 \ +--iters 200 \ +--deployment_device 'autoround' \ +--image_folder /path/to/coco/images/train2017/ \ +--question_file=self_made.json \ +--eval-path=/path/to/textvqa_data/ \ +--output_dir "./tmp_autoround" \ No newline at end of file