[Feature] add chartqa #1668

Open · wants to merge 2 commits into base: dev
3 changes: 2 additions & 1 deletion mmpretrain/datasets/__init__.py
@@ -34,6 +34,7 @@
]

if WITH_MULTIMODAL:
from .chartqa import ChartQA
from .coco_caption import COCOCaption
from .coco_retrieval import COCORetrieval
from .coco_vqa import COCOVQA
@@ -54,5 +55,5 @@
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA'
'VSR', 'VizWiz', 'OCRVQA', 'ChartQA'
])
115 changes: 115 additions & 0 deletions mmpretrain/datasets/chartqa.py
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List, Union

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.utils import is_abs

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class ChartQA(BaseDataset):
"""ChartQA dataset.

Dataset: https://github.com/vis-nlp/ChartQA

Folder structure:
data/chartqa
├── test
│ ├── png
│ ├── tables
│ ├── test_human.json
│ └── test_augmented.json
├── train
│ ├── png
│ ├── tables
│ ├── train_human.json
│ └── train_augmented.json
└── val
├── png
├── tables
├── val_human.json
└── val_augmented.json

Args:
data_root (str): The root directory for ``data_prefix`` and
``ann_file``.
data_prefix (str): The directory of images.
ann_file (str | List[str], optional): Annotation file path(s). ChartQA
typically uses a list such as ``['val_human.json',
'val_augmented.json']``; relative paths are joined with
``data_root``. Defaults to an empty string.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self,
data_root: str,
data_prefix: str,
ann_file: Union[str, List[str]] = '',
**kwargs):
super().__init__(
data_root=data_root,
data_prefix=dict(img_path=data_prefix),
ann_file=ann_file,
**kwargs,
)

def _join_prefix(self):
# Automatically join annotation file paths with ``self.data_root``
# if the entries of ``self.ann_file`` are not absolute paths.
if not any(is_abs(sub_ann_file)
for sub_ann_file in self.ann_file) and self.ann_file:
self.ann_file = [
osp.join(self.data_root, sub_ann_file)
for sub_ann_file in self.ann_file
]
# Automatically join the data directory with ``self.data_root`` if the
# path value in ``self.data_prefix`` is not an absolute path.
for data_key, prefix in self.data_prefix.items():
if isinstance(prefix, str):
if not is_abs(prefix):
self.data_prefix[data_key] = osp.join(
self.data_root, prefix)
else:
self.data_prefix[data_key] = prefix
else:
raise TypeError('prefix should be a string, but got '
f'{type(prefix)}')

def load_data_list(self) -> List[dict]:
"""Load data list."""
data_list = []

for sub_ann_file in self.ann_file:

annotations = mmengine.load(sub_ann_file)

for ann in annotations:

# ann example
# {
# 'imgname': '41699051005347.png',
# 'query': 'How many food item i...bar graph?',
# 'label': '14'
# }

data_info = dict(question=ann['query'])
data_info['image_id'] = ann['imgname']

img_path = mmengine.join_path(self.data_prefix['img_path'],
ann['imgname'])

data_info['img_path'] = img_path
data_info['gt_answer'] = ann['label']

if 'human' in sub_ann_file:
data_info['sub_set'] = 'ChartQA-H'
elif 'augmented' in sub_ann_file:
data_info['sub_set'] = 'ChartQA-M'
else:
raise ValueError(
f'Unsupported annotation file {sub_ann_file}: the file name '
'should contain "human" or "augmented".')

data_list.append(data_info)

return data_list
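
For reference, the snippet below sketches how the new dataset and metric might be wired into a test config. It is only an illustration, not part of this PR: the data paths, pipeline transforms, and batch settings are assumptions and would need to match the local ChartQA layout described in the docstring above.

# Hypothetical config sketch (paths and pipeline are placeholders).
test_dataloader = dict(
    batch_size=8,
    num_workers=4,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='ChartQA',
        data_root='data/chartqa',
        data_prefix='test/png',
        ann_file=['test/test_human.json', 'test/test_augmented.json'],
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', scale=(480, 480)),
            dict(
                type='PackInputs',
                meta_keys=['question', 'gt_answer', 'sub_set']),
        ],
    ),
)
test_evaluator = dict(type='ChartQARelaxACC')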
3 changes: 2 additions & 1 deletion mmpretrain/evaluation/metrics/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .caption import COCOCaption
from .chartqa import ChartQARelaxACC
from .gqa import GQAAcc
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
@@ -16,5 +17,5 @@
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
'RetrievalAveragePrecision'
'RetrievalAveragePrecision', 'ChartQARelaxACC'
]
130 changes: 130 additions & 0 deletions mmpretrain/evaluation/metrics/chartqa.py
@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS
from .vqa import _process_digit_article, _process_punctuation


@METRICS.register_module()
class ChartQARelaxACC(BaseMetric):
'''ChartQA relaxed-accuracy metric.

A numeric prediction is counted as correct if it lies within
``relax_thresh`` (relative) of the ground-truth value; other answers
require an exact string match.

Args:
full_score_weight (float): Accepted for interface consistency with the
other VQA metrics; currently unused by this metric. Defaults to 0.3.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, ``self.default_prefix``
('ChartQARelaxACC') will be used instead. Defaults to None.
relax_thresh (float): Relative tolerance used for numeric answers.
Defaults to 0.05.
'''
default_prefix = 'ChartQARelaxACC'

def __init__(self,
full_score_weight: float = 0.3,
collect_device: str = 'cpu',
prefix: Optional[str] = None,
relax_thresh: float = 0.05):
super().__init__(collect_device=collect_device, prefix=prefix)
self.full_score_weight = full_score_weight
self.relax_thresh = relax_thresh

def is_digit(self, x: str):
"""Check whether the answer string represents a number."""
# Signed integers and decimals (e.g. '14', '-3.5'); fall back to
# ``str.isnumeric`` for other plain digit strings.
is_decimal = bool(re.match(r'^[+-]?\d+(\.\d+)?$', x))
return is_decimal or str(x).isnumeric()

def process(self, data_batch, data_samples):
"""Process one batch of data samples.

The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.

Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for sample in data_samples:
gt_answer = sample.get('gt_answer')
sub_set = sample.get('sub_set')

is_digit = self.is_digit(gt_answer)

result = {
'pred_answer': sample.get('pred_answer'),
'gt_answer': gt_answer,
'is_digit': is_digit,
'sub_set': sub_set
}

self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.

Args:
results (List): The processed results of each batch.

Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
ChartQA_H_acc = []
ChartQA_M_acc = []
for result in results:
pred_answer = self._process_answer(result['pred_answer'])
gt_answer = result['gt_answer']
is_digit = result['is_digit']
sub_set = result['sub_set']

if is_digit:
if self.is_digit(pred_answer):
pred_answer = float(pred_answer)
gt_answer = float(gt_answer)
upper_bound = \
max(gt_answer - gt_answer * self.relax_thresh,
gt_answer + gt_answer * self.relax_thresh)
lower_bound = \
min(gt_answer - gt_answer * self.relax_thresh,
gt_answer + gt_answer * self.relax_thresh)
chart_acc = float(
all([
pred_answer >= lower_bound,
pred_answer <= upper_bound
]))
else:
chart_acc = 0.0
else:
chart_acc = float(pred_answer == gt_answer)

if sub_set == 'ChartQA-H':
ChartQA_H_acc.append(chart_acc)
elif sub_set == 'ChartQA-M':
ChartQA_M_acc.append(chart_acc)
else:
raise ValueError(f'Unknown subset {sub_set}; expected '
'"ChartQA-H" or "ChartQA-M".')

ChartQA_H_acc = sum(ChartQA_H_acc) / len(ChartQA_H_acc) * 100
ChartQA_M_acc = sum(ChartQA_M_acc) / len(ChartQA_M_acc) * 100

accuracy = (ChartQA_H_acc + ChartQA_M_acc) / 2

metrics = {
'ChartQA-H acc': ChartQA_H_acc,
'ChartQA-M acc': ChartQA_M_acc,
'Overall acc': accuracy
}

return metrics

def _process_answer(self, answer):
answer = answer.replace('\n', ' ')
answer = answer.replace('\t', ' ')
answer = answer.strip()
answer = _process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
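
To make the relaxed-accuracy rule above concrete, here is a minimal standalone sketch (the helper name relaxed_match is ours, not part of the PR): a numeric prediction counts as correct when it falls within relax_thresh (5% by default) of the ground-truth value, and the max/min bounds keep the interval valid for negative ground truths.

def relaxed_match(pred: float, gt: float, thresh: float = 0.05) -> bool:
    """Return True if ``pred`` lies within ``thresh`` (relative) of ``gt``."""
    # max/min mirror the bound computation in compute_metrics above and
    # keep lower <= upper even when ``gt`` is negative.
    upper = max(gt - gt * thresh, gt + gt * thresh)
    lower = min(gt - gt * thresh, gt + gt * thresh)
    return lower <= pred <= upper


# With gt = 14 the accepted interval is [13.3, 14.7].
assert relaxed_match(14.5, 14.0)
assert not relaxed_match(15.0, 14.0)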