From 5e780f8f794ddeeeec4e27c96fa3f49d17863bfc Mon Sep 17 00:00:00 2001 From: wangbo-zhao Date: Mon, 26 Jun 2023 21:58:20 +0800 Subject: [PATCH 1/2] add chartqa --- a.py | 33 +++++ configs/flamingo/flamingo_zeroshot_ChartQA.py | 83 ++++++++++++ mmpretrain/datasets/__init__.py | 4 +- mmpretrain/datasets/chartqa.py | 115 +++++++++++++++++ mmpretrain/evaluation/metrics/__init__.py | 3 +- mmpretrain/evaluation/metrics/chartqa.py | 120 ++++++++++++++++++ 6 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 a.py create mode 100644 configs/flamingo/flamingo_zeroshot_ChartQA.py create mode 100644 mmpretrain/datasets/chartqa.py create mode 100644 mmpretrain/evaluation/metrics/chartqa.py diff --git a/a.py b/a.py new file mode 100644 index 00000000000..68f18f92a01 --- /dev/null +++ b/a.py @@ -0,0 +1,33 @@ +from mmpretrain.datasets import ChartQA + + +test_pipeline = [ + dict(type='mmpretrain.LoadImageFromFile'), + dict( + type='mmpretrain.ResizeEdge', + scale=224, + interpolation='bicubic', + backend='pillow'), + dict(type='mmpretrain.CenterCrop', crop_size=(224, 224)), + dict( + type='mmpretrain.PackInputs', + algorithm_keys=['question', 'gt_answer', 'sub_set'], + meta_keys=['image_id'], + ), +] + + +dataset = ChartQA( + data_root='data/chartqa/test', + data_prefix='png', + ann_file=['test_human.json', 'test_augmented.json'], + pipeline=test_pipeline) + +# dataset = ChartQA( +# data_root='data/chartqa/train', +# data_prefix='png', +# ann_file=['train_human.json', ], +# pipeline=test_pipeline) + + +print("a") \ No newline at end of file diff --git a/configs/flamingo/flamingo_zeroshot_ChartQA.py b/configs/flamingo/flamingo_zeroshot_ChartQA.py new file mode 100644 index 00000000000..55f7be724aa --- /dev/null +++ b/configs/flamingo/flamingo_zeroshot_ChartQA.py @@ -0,0 +1,83 @@ +_base_ = [ + '../_base_/default_runtime.py', +] + +zeroshot_prompt = ( + 'Question:In which year the value was 51? Short Answer:2014<|endofchunk|>' # noqa: E501 + 'Question:Is the value of Favorable 38 in 2015? 
Short Answer:Yes<|endofchunk|>' # noqa: E501 +) + +# model settings +model = dict( + type='Flamingo', + tokenizer=dict( + type='LlamaTokenizer', name_or_path='decapoda-research/llama-7b-hf'), + vision_encoder=dict( + type='VisionTransformer', + arch='l', + patch_size=14, + pre_norm=True, + norm_cfg=dict(type='LN', eps=1e-5), + layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), + final_norm=False, + out_type='raw', + pretrained= + '/mnt/petrelfs/zhaowangbo/openmmlab/vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth', + ), + lang_encoder=dict( + base=dict( + type='AutoModelForCausalLM', + name_or_path='decapoda-research/llama-7b-hf', + local_files_only=True), + adapter=dict( + type='FlamingoLMAdapter', + vis_hidden_size=1024, + cross_attn_every_n_layers=4, + use_media_placement_augmentation=False), + ), + task='vqa', + zeroshot_prompt=zeroshot_prompt, + final_prompt_tmpl='Question:{question} Short Answer:', + generation_cfg=dict(num_beams=3, max_new_tokens=5, length_penalty=-2.0)) + +# data settings +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[122.770938, 116.7460125, 104.09373615], + std=[68.5005327, 66.6321579, 70.32316305], + to_rgb=True, +) + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='ResizeEdge', + scale=224, + interpolation='bicubic', + backend='pillow'), + dict(type='CenterCrop', crop_size=(224, 224)), + dict( + type='PackInputs', + algorithm_keys=['question', 'gt_answer', 'sub_set'], + meta_keys=['image_id'], + ), +] + +test_dataloader = dict( + batch_size=1, + num_workers=8, + dataset=dict( + type='ChartQA', + data_root='data/chartqa/test', + data_prefix='png', + ann_file=['test_human.json', 'test_augmented.json'], + pipeline=test_pipeline + ), + sampler=dict(type='DefaultSampler', shuffle=False), + persistent_workers=True, +) + +test_evaluator = dict(type='ChartQARelaxACC') + +# schedule settings +test_cfg = dict() diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py index b7b6be47dce..05321c3f637 100644 --- a/mmpretrain/datasets/__init__.py +++ b/mmpretrain/datasets/__init__.py @@ -49,10 +49,12 @@ from .visual_genome import VisualGenomeQA from .vizwiz import VizWiz from .vsr import VSR + from .chartqa import ChartQA + __all__.extend([ 'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption', 'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', - 'VSR', 'VizWiz', 'OCRVQA' + 'VSR', 'VizWiz', 'OCRVQA', 'ChartQA' ]) diff --git a/mmpretrain/datasets/chartqa.py b/mmpretrain/datasets/chartqa.py new file mode 100644 index 00000000000..a3736b44c4d --- /dev/null +++ b/mmpretrain/datasets/chartqa.py @@ -0,0 +1,115 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import Counter +from typing import List + +import mmengine +from mmengine.dataset import BaseDataset + +from mmpretrain.registry import DATASETS +from mmengine.utils import is_abs +import os.path as osp + + +@DATASETS.register_module() +class ChartQA(BaseDataset): + """ChartQA dataset. + + dataset:https://github.com/vis-nlp/ChartQA + + folder structure: + data/chartqa + ├── test + │ ├── png + │ ├── tables + │ ├── test_human.json + │ └── test_augmented.json + ├── train + │ ├── png + │ ├── tables + │ ├── train_human.json + │ └── train_augmented.json + └── val + ├── png + ├── tables + ├── val_human.json + └── val_augmented.json + Args: + data_root (str): The root directory for ``data_prefix``, ``ann_file`` + and ``question_file``. 
+ data_prefix (str): The directory of images. + question_file (str): Question file path. + ann_file (str, optional): Annotation file path for training and + validation. Defaults to an empty string. + **kwargs: Other keyword arguments in :class:`BaseDataset`. + """ + + def __init__(self, + data_root: str, + data_prefix: str, + ann_file: str = '', + **kwarg): + super().__init__( + data_root=data_root, + data_prefix=dict(img_path=data_prefix), + ann_file=ann_file, + **kwarg, + ) + + def _join_prefix(self): + # Automatically join annotation file path with `self.root` if + # `self.ann_file` is not an absolute path. + if not any(is_abs(sub_ann_file) for sub_ann_file in self.ann_file) and self.ann_file: + self.ann_file = [osp.join(self.data_root, sub_ann_file) for sub_ann_file in self.ann_file] + # Automatically join data directory with `self.root` if path value in + # `self.data_prefix` is not an absolute path. + for data_key, prefix in self.data_prefix.items(): + if isinstance(prefix, str): + if not is_abs(prefix): + self.data_prefix[data_key] = osp.join( + self.data_root, prefix) + else: + self.data_prefix[data_key] = prefix + else: + raise TypeError('prefix should be a string, but got ' + f'{type(prefix)}') + + def load_data_list(self) -> List[dict]: + """Load data list.""" + data_list = [] + + for sub_ann_file in self.ann_file: + + annotations = mmengine.load(sub_ann_file) + + + + for ann in annotations: + + # ann example + # { + # 'imgname': '41699051005347.png' + # 'query': 'How many food item i...bar graph?', + # 'label': '14' + # } + + data_info = dict(question=ann['query']) + data_info['image_id'] = ann['imgname'] + + img_path = mmengine.join_path(self.data_prefix['img_path'], + ann['imgname']) + + data_info['img_path'] = img_path + data_info['gt_answer'] = ann['label'] + + if 'human' in sub_ann_file: + data_info['sub_set'] = 'ChartQA-H' + elif 'augmented' in sub_ann_file: + data_info['sub_set'] = 'ChartQA-M' + else: + raise ValueError(f'Do not support to subset {sub_ann_file}.') + + data_list.append(data_info) + + return data_list + + diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py index 7f5a4f36b41..f141ddf77ad 100644 --- a/mmpretrain/evaluation/metrics/__init__.py +++ b/mmpretrain/evaluation/metrics/__init__.py @@ -10,11 +10,12 @@ from .visual_grounding_eval import VisualGroundingMetric from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric from .vqa import ReportVQA, VQAAcc +from .chartqa import ChartQARelaxACC __all__ = [ 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', 'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric', 'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption', 'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave', - 'RetrievalAveragePrecision' + 'RetrievalAveragePrecision', 'ChartQARelaxACC' ] diff --git a/mmpretrain/evaluation/metrics/chartqa.py b/mmpretrain/evaluation/metrics/chartqa.py new file mode 100644 index 00000000000..d722dc21635 --- /dev/null +++ b/mmpretrain/evaluation/metrics/chartqa.py @@ -0,0 +1,120 @@ +from typing import List, Optional + +import mmengine +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger + +from mmpretrain.registry import METRICS + +import re +from .vqa import _process_punctuation, _process_digit_article +@METRICS.register_module() +class ChartQARelaxACC(BaseMetric): + '''ChartQARelaxACC. 
+    Args:
+        relax_thresh (float): The relative tolerance used when comparing
+            numeric answers, following ChartQA's relaxed accuracy protocol.
+            Defaults to 0.05.
+        full_score_weight (float): Currently unused by this metric.
+            Defaults to 0.3.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Defaults to None.
+    '''
+    default_prefix = 'ChartQARelaxACC'
+
+    def __init__(self,
+                 full_score_weight: float = 0.3,
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 relax_thresh: float = 0.05 ):
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.full_score_weight = full_score_weight
+        self.relax_thresh = relax_thresh
+
+    def is_digit(self, x: str):
+        a = bool(re.match(r'^[+-]?\d+\.\d+$', x))
+        b = str(x).isnumeric()
+        return any([a, b])
+
+
+    def process(self, data_batch, data_samples):
+        """Process one batch of data samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch: A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for sample in data_samples:
+            gt_answer = sample.get('gt_answer')
+            sub_set = sample.get('sub_set')
+
+            is_digit = self.is_digit(gt_answer)
+
+            result = {
+                'pred_answer': sample.get('pred_answer'),
+                'gt_answer': gt_answer,
+                'is_digit': is_digit,
+                'sub_set': sub_set
+            }
+
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+ """ + ChartQA_H_acc = [] + ChartQA_M_acc = [] + for result in results: + pred_answer = self._process_answer(result['pred_answer']) + gt_answer = result['gt_answer'] + is_digit = result['is_digit'] + sub_set = result['sub_set'] + + if is_digit: + if self.is_digit(pred_answer): + pred_answer = float(pred_answer) + gt_answer = float(gt_answer) + upper_bound = max(gt_answer - gt_answer * self.relax_thresh, gt_answer + gt_answer * self.relax_thresh) + lower_bound = min(gt_answer - gt_answer * self.relax_thresh, gt_answer + gt_answer * self.relax_thresh) + chart_acc = float(all([pred_answer >= lower_bound, pred_answer <= upper_bound])) + else: + chart_acc = 0.0 + else: + chart_acc = float(pred_answer == gt_answer) + + if sub_set == 'ChartQA-H': + ChartQA_H_acc.append(chart_acc) + elif sub_set == 'ChartQA-M': + ChartQA_M_acc.append(chart_acc) + else: + raise ValueError(f'Do not support to subset {sub_set}.') + + + ChartQA_H_acc = sum(ChartQA_H_acc) / len(ChartQA_H_acc) * 100 + ChartQA_M_acc = sum(ChartQA_M_acc) / len(ChartQA_M_acc) * 100 + + accuracy = (ChartQA_H_acc + ChartQA_M_acc) / 2 + + metrics = {'ChartQA-H acc': ChartQA_H_acc, 'ChartQA-M acc': ChartQA_M_acc, 'Overall acc': accuracy} + + return metrics + + def _process_answer(self, answer): + answer = answer.replace('\n', ' ') + answer = answer.replace('\t', ' ') + answer = answer.strip() + answer = _process_punctuation(answer) + answer = _process_digit_article(answer) + return answer \ No newline at end of file From 0cdbaca56edd9f9f934bde2152c59094b284f4f6 Mon Sep 17 00:00:00 2001 From: wangbo-zhao Date: Tue, 27 Jun 2023 11:24:31 +0800 Subject: [PATCH 2/2] pre-commit --- a.py | 33 -------- configs/flamingo/flamingo_zeroshot_ChartQA.py | 83 ------------------- mmpretrain/datasets/__init__.py | 3 +- mmpretrain/datasets/chartqa.py | 40 ++++----- mmpretrain/evaluation/metrics/__init__.py | 2 +- mmpretrain/evaluation/metrics/chartqa.py | 50 ++++++----- 6 files changed, 52 insertions(+), 159 deletions(-) delete mode 100644 a.py delete mode 100644 configs/flamingo/flamingo_zeroshot_ChartQA.py diff --git a/a.py b/a.py deleted file mode 100644 index 68f18f92a01..00000000000 --- a/a.py +++ /dev/null @@ -1,33 +0,0 @@ -from mmpretrain.datasets import ChartQA - - -test_pipeline = [ - dict(type='mmpretrain.LoadImageFromFile'), - dict( - type='mmpretrain.ResizeEdge', - scale=224, - interpolation='bicubic', - backend='pillow'), - dict(type='mmpretrain.CenterCrop', crop_size=(224, 224)), - dict( - type='mmpretrain.PackInputs', - algorithm_keys=['question', 'gt_answer', 'sub_set'], - meta_keys=['image_id'], - ), -] - - -dataset = ChartQA( - data_root='data/chartqa/test', - data_prefix='png', - ann_file=['test_human.json', 'test_augmented.json'], - pipeline=test_pipeline) - -# dataset = ChartQA( -# data_root='data/chartqa/train', -# data_prefix='png', -# ann_file=['train_human.json', ], -# pipeline=test_pipeline) - - -print("a") \ No newline at end of file diff --git a/configs/flamingo/flamingo_zeroshot_ChartQA.py b/configs/flamingo/flamingo_zeroshot_ChartQA.py deleted file mode 100644 index 55f7be724aa..00000000000 --- a/configs/flamingo/flamingo_zeroshot_ChartQA.py +++ /dev/null @@ -1,83 +0,0 @@ -_base_ = [ - '../_base_/default_runtime.py', -] - -zeroshot_prompt = ( - 'Question:In which year the value was 51? Short Answer:2014<|endofchunk|>' # noqa: E501 - 'Question:Is the value of Favorable 38 in 2015? 
Short Answer:Yes<|endofchunk|>' # noqa: E501 -) - -# model settings -model = dict( - type='Flamingo', - tokenizer=dict( - type='LlamaTokenizer', name_or_path='decapoda-research/llama-7b-hf'), - vision_encoder=dict( - type='VisionTransformer', - arch='l', - patch_size=14, - pre_norm=True, - norm_cfg=dict(type='LN', eps=1e-5), - layer_cfgs=dict(act_cfg=dict(type='QuickGELU')), - final_norm=False, - out_type='raw', - pretrained= - '/mnt/petrelfs/zhaowangbo/openmmlab/vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth', - ), - lang_encoder=dict( - base=dict( - type='AutoModelForCausalLM', - name_or_path='decapoda-research/llama-7b-hf', - local_files_only=True), - adapter=dict( - type='FlamingoLMAdapter', - vis_hidden_size=1024, - cross_attn_every_n_layers=4, - use_media_placement_augmentation=False), - ), - task='vqa', - zeroshot_prompt=zeroshot_prompt, - final_prompt_tmpl='Question:{question} Short Answer:', - generation_cfg=dict(num_beams=3, max_new_tokens=5, length_penalty=-2.0)) - -# data settings -data_preprocessor = dict( - type='MultiModalDataPreprocessor', - mean=[122.770938, 116.7460125, 104.09373615], - std=[68.5005327, 66.6321579, 70.32316305], - to_rgb=True, -) - -test_pipeline = [ - dict(type='LoadImageFromFile'), - dict( - type='ResizeEdge', - scale=224, - interpolation='bicubic', - backend='pillow'), - dict(type='CenterCrop', crop_size=(224, 224)), - dict( - type='PackInputs', - algorithm_keys=['question', 'gt_answer', 'sub_set'], - meta_keys=['image_id'], - ), -] - -test_dataloader = dict( - batch_size=1, - num_workers=8, - dataset=dict( - type='ChartQA', - data_root='data/chartqa/test', - data_prefix='png', - ann_file=['test_human.json', 'test_augmented.json'], - pipeline=test_pipeline - ), - sampler=dict(type='DefaultSampler', shuffle=False), - persistent_workers=True, -) - -test_evaluator = dict(type='ChartQARelaxACC') - -# schedule settings -test_cfg = dict() diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py index 05321c3f637..5bb5b230002 100644 --- a/mmpretrain/datasets/__init__.py +++ b/mmpretrain/datasets/__init__.py @@ -34,6 +34,7 @@ ] if WITH_MULTIMODAL: + from .chartqa import ChartQA from .coco_caption import COCOCaption from .coco_retrieval import COCORetrieval from .coco_vqa import COCOVQA @@ -49,8 +50,6 @@ from .visual_genome import VisualGenomeQA from .vizwiz import VizWiz from .vsr import VSR - from .chartqa import ChartQA - __all__.extend([ 'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption', diff --git a/mmpretrain/datasets/chartqa.py b/mmpretrain/datasets/chartqa.py index a3736b44c4d..180eaa78970 100644 --- a/mmpretrain/datasets/chartqa.py +++ b/mmpretrain/datasets/chartqa.py @@ -1,13 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from collections import Counter +import os.path as osp from typing import List import mmengine from mmengine.dataset import BaseDataset +from mmengine.utils import is_abs from mmpretrain.registry import DATASETS -from mmengine.utils import is_abs -import os.path as osp @DATASETS.register_module() @@ -53,13 +52,17 @@ def __init__(self, data_prefix=dict(img_path=data_prefix), ann_file=ann_file, **kwarg, - ) - + ) + def _join_prefix(self): # Automatically join annotation file path with `self.root` if # `self.ann_file` is not an absolute path. 
- if not any(is_abs(sub_ann_file) for sub_ann_file in self.ann_file) and self.ann_file: - self.ann_file = [osp.join(self.data_root, sub_ann_file) for sub_ann_file in self.ann_file] + if not any(is_abs(sub_ann_file) + for sub_ann_file in self.ann_file) and self.ann_file: + self.ann_file = [ + osp.join(self.data_root, sub_ann_file) + for sub_ann_file in self.ann_file + ] # Automatically join data directory with `self.root` if path value in # `self.data_prefix` is not an absolute path. for data_key, prefix in self.data_prefix.items(): @@ -72,16 +75,14 @@ def _join_prefix(self): else: raise TypeError('prefix should be a string, but got ' f'{type(prefix)}') - + def load_data_list(self) -> List[dict]: """Load data list.""" data_list = [] - + for sub_ann_file in self.ann_file: - - annotations = mmengine.load(sub_ann_file) - + annotations = mmengine.load(sub_ann_file) for ann in annotations: @@ -91,25 +92,24 @@ def load_data_list(self) -> List[dict]: # 'query': 'How many food item i...bar graph?', # 'label': '14' # } - + data_info = dict(question=ann['query']) data_info['image_id'] = ann['imgname'] - + img_path = mmengine.join_path(self.data_prefix['img_path'], - ann['imgname']) - + ann['imgname']) + data_info['img_path'] = img_path data_info['gt_answer'] = ann['label'] - + if 'human' in sub_ann_file: data_info['sub_set'] = 'ChartQA-H' elif 'augmented' in sub_ann_file: data_info['sub_set'] = 'ChartQA-M' else: - raise ValueError(f'Do not support to subset {sub_ann_file}.') + raise ValueError( + f'Do not support to subset {sub_ann_file}.') data_list.append(data_info) return data_list - - diff --git a/mmpretrain/evaluation/metrics/__init__.py b/mmpretrain/evaluation/metrics/__init__.py index f141ddf77ad..e0dee70f761 100644 --- a/mmpretrain/evaluation/metrics/__init__.py +++ b/mmpretrain/evaluation/metrics/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .caption import COCOCaption +from .chartqa import ChartQARelaxACC from .gqa import GQAAcc from .multi_label import AveragePrecision, MultiLabelMetric from .multi_task import MultiTasksMetric @@ -10,7 +11,6 @@ from .visual_grounding_eval import VisualGroundingMetric from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric from .vqa import ReportVQA, VQAAcc -from .chartqa import ChartQARelaxACC __all__ = [ 'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision', diff --git a/mmpretrain/evaluation/metrics/chartqa.py b/mmpretrain/evaluation/metrics/chartqa.py index d722dc21635..c3294499b38 100644 --- a/mmpretrain/evaluation/metrics/chartqa.py +++ b/mmpretrain/evaluation/metrics/chartqa.py @@ -1,13 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re from typing import List, Optional -import mmengine from mmengine.evaluator import BaseMetric -from mmengine.logging import MMLogger from mmpretrain.registry import METRICS +from .vqa import _process_digit_article, _process_punctuation + -import re -from .vqa import _process_punctuation, _process_digit_article @METRICS.register_module() class ChartQARelaxACC(BaseMetric): '''ChartQARelaxACC. 
@@ -28,16 +28,15 @@ def __init__(self, full_score_weight: float = 0.3, collect_device: str = 'cpu', prefix: Optional[str] = None, - relax_thresh: float = 0.05 ): + relax_thresh: float = 0.05): super().__init__(collect_device=collect_device, prefix=prefix) self.full_score_weight = full_score_weight self.relax_thresh = relax_thresh - def is_digit(self, x: str): + def is_digit(self, x: str): a = bool(re.match(r'^[+-]?\d+\.\d+$', x)) b = str(x).isnumeric() return any([a, b]) - def process(self, data_batch, data_samples): """Process one batch of data samples. @@ -54,7 +53,7 @@ def process(self, data_batch, data_samples): sub_set = sample.get('sub_set') is_digit = self.is_digit(gt_answer) - + result = { 'pred_answer': sample.get('pred_answer'), 'gt_answer': gt_answer, @@ -81,40 +80,51 @@ def compute_metrics(self, results: List): gt_answer = result['gt_answer'] is_digit = result['is_digit'] sub_set = result['sub_set'] - + if is_digit: if self.is_digit(pred_answer): pred_answer = float(pred_answer) gt_answer = float(gt_answer) - upper_bound = max(gt_answer - gt_answer * self.relax_thresh, gt_answer + gt_answer * self.relax_thresh) - lower_bound = min(gt_answer - gt_answer * self.relax_thresh, gt_answer + gt_answer * self.relax_thresh) - chart_acc = float(all([pred_answer >= lower_bound, pred_answer <= upper_bound])) + upper_bound = \ + max(gt_answer - gt_answer * self.relax_thresh, + gt_answer + gt_answer * self.relax_thresh) + lower_bound = \ + min(gt_answer - gt_answer * self.relax_thresh, + gt_answer + gt_answer * self.relax_thresh) + chart_acc = float( + all([ + pred_answer >= lower_bound, + pred_answer <= upper_bound + ])) else: chart_acc = 0.0 else: chart_acc = float(pred_answer == gt_answer) - + if sub_set == 'ChartQA-H': ChartQA_H_acc.append(chart_acc) elif sub_set == 'ChartQA-M': ChartQA_M_acc.append(chart_acc) else: raise ValueError(f'Do not support to subset {sub_set}.') - - + ChartQA_H_acc = sum(ChartQA_H_acc) / len(ChartQA_H_acc) * 100 ChartQA_M_acc = sum(ChartQA_M_acc) / len(ChartQA_M_acc) * 100 - + accuracy = (ChartQA_H_acc + ChartQA_M_acc) / 2 - metrics = {'ChartQA-H acc': ChartQA_H_acc, 'ChartQA-M acc': ChartQA_M_acc, 'Overall acc': accuracy} - + metrics = { + 'ChartQA-H acc': ChartQA_H_acc, + 'ChartQA-M acc': ChartQA_M_acc, + 'Overall acc': accuracy + } + return metrics - + def _process_answer(self, answer): answer = answer.replace('\n', ' ') answer = answer.replace('\t', ' ') answer = answer.strip() answer = _process_punctuation(answer) answer = _process_digit_article(answer) - return answer \ No newline at end of file + return answer
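
A minimal smoke test for the new dataset and metric, in the spirit of the scratch a.py removed by the second commit. This is an illustrative sketch only: it assumes the ChartQA data has been downloaded to data/chartqa/ with the layout described in the ChartQA docstring, and the metric samples below are hand-crafted values rather than real model predictions.

from mmpretrain.datasets import ChartQA
from mmpretrain.evaluation.metrics import ChartQARelaxACC

# Build the test split from both annotation files; an empty pipeline is
# enough to check that questions, answers and image paths are resolved.
dataset = ChartQA(
    data_root='data/chartqa/test',
    data_prefix='png',
    ann_file=['test_human.json', 'test_augmented.json'],
    pipeline=[])
print(len(dataset), dataset[0]['question'], dataset[0]['sub_set'])

# The metric can be exercised without a model: numeric answers count as
# correct when within +/- relax_thresh (5%) of the label, other answers
# must match exactly after VQA-style normalization.
metric = ChartQARelaxACC()
metric.process(None, [
    dict(pred_answer='14.5', gt_answer='14.0', sub_set='ChartQA-H'),  # within 5% -> 1.0
    dict(pred_answer='blue', gt_answer='red', sub_set='ChartQA-M'),   # mismatch -> 0.0
])
print(metric.compute_metrics(metric.results))
# Expected: {'ChartQA-H acc': 100.0, 'ChartQA-M acc': 0.0, 'Overall acc': 50.0}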