Feature/phrase grounding recall filter (#139)
* + add phrase_grounding_recall_filter op

* * update new OP
+ make llava_to_dj and dj_to_llava support only_keep_caption mode

* + Add unit test for phrase_grounding

* + remove hf model caches automatically for unittest

* * download required nltk data when initializing the phrase_grounding_recall_filter

* * output the cleaning log when the cleaning actually happens

* * update Operator docs

* * fix some typos

* * removing hf models automatically after unit test is finished for clip and blip
HYLcool authored Jan 4, 2024
1 parent 0431f25 commit ad445c9
Showing 15 changed files with 957 additions and 154 deletions.
59 changes: 35 additions & 24 deletions configs/config_all.yaml

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions data_juicer/ops/filter/__init__.py
@@ -5,10 +5,11 @@
               image_shape_filter, image_size_filter,
               image_text_matching_filter, image_text_similarity_filter,
               language_id_score_filter, maximum_line_length_filter,
               perplexity_filter, special_characters_filter,
               specified_field_filter, specified_numeric_field_filter,
               stopwords_filter, suffix_filter, text_action_filter,
               text_entity_dependency_filter, text_length_filter,
               token_num_filter, word_num_filter, word_repetition_filter)
               perplexity_filter, phrase_grounding_recall_filter,
               special_characters_filter, specified_field_filter,
               specified_numeric_field_filter, stopwords_filter, suffix_filter,
               text_action_filter, text_entity_dependency_filter,
               text_length_filter, token_num_filter, word_num_filter,
               word_repetition_filter)

# yapf: enable
286 changes: 286 additions & 0 deletions data_juicer/ops/filter/phrase_grounding_recall_filter.py
@@ -0,0 +1,286 @@
from typing import List

import numpy as np
from jsonargparse.typing import ClosedUnitInterval
from loguru import logger
from PIL import ImageOps

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import (SpecialTokens, iou, load_image,
                                        remove_special_tokens)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'phrase_grounding_recall_filter'

with AvailabilityChecking(['torch', 'transformers', 'nltk'], OP_NAME):

    import torch
    import transformers  # noqa: F401

    # avoid hanging when calling clip in multiprocessing
    torch.set_num_threads(1)

    import nltk


# NER algorithm adapted from GLIP starts
# https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/engine/predictor_glip.py#L107-L127
def find_noun_phrases(caption: str) -> List[str]:
    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = list()
    for subtree in result.subtrees():
        if subtree.label() == 'NP':
            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))

    return noun_phrases


def remove_punctuation(text: str) -> str:
    punct = [
        '|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', '\'', '\"', '’',
        '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
    ]
    for p in punct:
        text = text.replace(p, '')
    return text.strip()


def run_ner(caption):
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    noun_phrases = list(set(noun_phrases))  # remove duplicate ners
    return noun_phrases


# NER algorithm adapted from GLIP ends
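
# Illustrative note (editor's sketch, not part of the committed file): with
# the grammar above, a caption such as "A red bicycle leans against the old
# wall" would typically yield ['a red bicycle', 'the old wall'] from
# run_ner(). The exact chunks depend on the nltk POS tagger, and the order is
# not guaranteed because duplicates are dropped through a set.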


@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class PhraseGroundingRecallFilter(Filter):
    """Filter to keep samples whose phrase grounding recall, i.e. the
    fraction of phrases extracted from the text that can be located in the
    corresponding images, is within a specified range."""

    def __init__(self,
                 hf_owlvit='google/owlvit-base-patch32',
                 min_recall: ClosedUnitInterval = 0.1,
                 max_recall: ClosedUnitInterval = 1.0,
                 horizontal_flip: bool = False,
                 vertical_flip: bool = False,
                 any_or_all: str = 'any',
                 reduce_mode: str = 'avg',
                 iou_thr: ClosedUnitInterval = 0.5,
                 large_area_ratio_thr: ClosedUnitInterval = 0.95,
                 conf_thr: ClosedUnitInterval = 0.0,
                 *args,
                 **kwargs):
"""
Initialization method.
:param hf_owlvit: Owl-ViT model name on huggingface to locate the
phrases extracted from the text.
:param min_recall: The min phrase grounding recall to keep samples.
:param max_recall: The max phrase grounding recall to keep samples.
:param horizontal_flip: Flip image horizontally (left to right).
:param vertical_flip: Flip image vertically (top to bottom).
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param reduce_mode: reduce mode when one text corresponds to
multiple images in a chunk.
'avg': Take the average of multiple values
'max': Take the max of multiple values
'min': Take the min of multiple values
:param iou_thr: the IoU threshold for NMS-like post-process. If two
predicted bboxes are overlap with an IoU larger than this
threshold, the bbox with less confidence will be removed. Default:
0.5.
:param large_area_ratio_thr: the area ratio threshold for filtering out
those large predicted bboxes. If the area of a predicted bbox
accounts for more than this ratio threshold of the whole image
area, this bbox will be removed. Default: 0.95.
:param conf_thr: the confidence score threshold for removing
low-confidence bboxes. If the confidence score of a predicted bbox
is lower than the threshold, this bbox will be removed. Default: 0.
:param args: extra args
:param kwargs: extra args
"""
        super().__init__(*args, **kwargs)
        self.min_recall = min_recall
        self.max_recall = max_recall
        if reduce_mode not in ['avg', 'max', 'min']:
            raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
                             f'Can only be one of ["avg", "max", "min"].')
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')
        self.model_type = 'hf_owlvit'
        self.model_key = prepare_model(model_type=self.model_type,
                                       model_key=hf_owlvit)
        self.reduce_mode = reduce_mode
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

        self.iou_thr = iou_thr
        self.large_area_ratio_thr = large_area_ratio_thr
        self.conf_thr = conf_thr

        requires_nltk_data = ['punkt', 'averaged_perceptron_tagger']
        logger.info(f'Downloading nltk data of {requires_nltk_data}...')
        for nltk_data_pkg in requires_nltk_data:
            nltk.download(nltk_data_pkg)

    def compute_stats(self, sample, context=False):
        # check if it's computed already
        if StatsKeys.phrase_grounding_recall in sample[Fields.stats]:
            return sample

        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            sample[Fields.stats][StatsKeys.phrase_grounding_recall] = np.array(
                [], dtype=np.float64)
            return sample

        # load images
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if context and loaded_image_key in sample[Fields.context]:
                # load from context
                images[loaded_image_key] = sample[
                    Fields.context][loaded_image_key]
            else:
                if loaded_image_key not in images:
                    # avoid loading the same images
                    image = load_image(loaded_image_key)
                    images[loaded_image_key] = image
                    if context:
                        # store the image data into context
                        sample[Fields.context][loaded_image_key] = image

        text = sample[self.text_key]
        offset = 0
        recalls = []
        model, processor = get_model(self.model_key,
                                     model_type=self.model_type)

        for chunk in text.split(SpecialTokens.eoc):
            count = chunk.count(SpecialTokens.image)

            # no image or no text
            if count == 0 or len(chunk) == 0:
                continue
            else:
                text_this_chunk = remove_special_tokens(chunk)
                ners_this_chunk = run_ner(text_this_chunk)
                num_ners = len(ners_this_chunk)
                if num_ners <= 0:
                    # no ners found, just skip this chunk
                    recalls.append(1.0)
                    continue
                images_this_chunk = []
                for image_key in loaded_image_keys[offset:offset + count]:
                    image = images[image_key]
                    if self.horizontal_flip:
                        image = ImageOps.mirror(image)
                    if self.vertical_flip:
                        image = ImageOps.flip(image)
                    images_this_chunk.append(image)

                ners_batch = [ners_this_chunk] * len(images_this_chunk)
                inputs = processor(text=ners_batch,
                                   images=images_this_chunk,
                                   return_tensors='pt',
                                   padding=True,
                                   truncation=True)

                with torch.no_grad():
                    outputs = model(**inputs)
                    target_sizes = torch.tensor(
                        [img.size[::-1] for img in images_this_chunk])
                    results = processor.post_process_object_detection(
                        outputs,
                        threshold=self.conf_thr,
                        target_sizes=target_sizes)

                image_recalls = []
                for idx, result in enumerate(results):
                    scores = result['scores']
                    labels = result['labels']
                    boxes = result['boxes']

                    # sort by the confidence scores
                    # and only keep the first num_ners predictions
                    order_idx = scores.argsort(descending=True)
                    scores = scores[order_idx].tolist()[:num_ners]
                    labels = labels[order_idx].tolist()[:num_ners]
                    boxes = boxes[order_idx].tolist()[:num_ners]

                    image_area = target_sizes[idx].prod()
                    hit = {}
                    for box, label, score in zip(boxes, labels, scores):
                        # this ner is already hit
                        if ners_this_chunk[label] in hit:
                            continue
                        # skip boxes nearly covering the whole image
                        xmin, ymin, xmax, ymax = box
                        box_area = (xmax - xmin) * (ymax - ymin)
                        if 1.0 * box_area / image_area > \
                                self.large_area_ratio_thr:
                            continue
                        # skip overlapped boxes with nms-like method
                        suppressed = False
                        for ner in hit:
                            if iou(box, hit[ner][0]) > self.iou_thr:
                                suppressed = True
                                break
                        if suppressed:
                            continue

                        # record the new hit box
                        hit[ners_this_chunk[label]] = (box, score)

                    recall = 1.0 * len(hit) / num_ners
                    image_recalls.append(recall)

                if self.reduce_mode == 'avg':
                    image_recall = sum(image_recalls) / len(image_recalls)
                elif self.reduce_mode == 'max':
                    image_recall = max(image_recalls)
                else:
                    image_recall = min(image_recalls)

                recalls.append(image_recall)
                offset += count
        sample[Fields.stats][StatsKeys.phrase_grounding_recall] = recalls

        return sample
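
    # Editor's sketch of the recall defined above (not part of the committed
    # file): if a chunk's caption yields 4 noun phrases and Owl-ViT keeps
    # surviving boxes for 3 of them after the confidence, large-area and
    # NMS-like filtering, that chunk's recall is 3 / 4 = 0.75.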

    def process(self, sample):
        recalls = sample[Fields.stats][StatsKeys.phrase_grounding_recall]
        if len(recalls) <= 0:
            return True

        keep_bools = np.array([
            self.min_recall <= recall <= self.max_recall for recall in recalls
        ])

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()
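
For reference, here is a minimal sketch (an editor's illustration, not part of this commit) of how the new op might be driven directly. The 'text'/'images' sample keys and the compute_stats/process calls follow the conventions used by other data_juicer filter ops; 'image.jpg' is a placeholder path and the recall value in the comment is hypothetical.

from data_juicer.ops.filter.phrase_grounding_recall_filter import \
    PhraseGroundingRecallFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

op = PhraseGroundingRecallFilter(hf_owlvit='google/owlvit-base-patch32',
                                 min_recall=0.5,
                                 max_recall=1.0)
sample = {
    'text': f'{SpecialTokens.image} a woman sitting on a bench with a dog'
            f'{SpecialTokens.eoc}',
    'images': ['image.jpg'],
    Fields.stats: {},
}
sample = op.compute_stats(sample)
print(sample[Fields.stats]['phrase_grounding_recall'])  # e.g. [0.75]
print(op.process(sample))  # True if the recall lies in [0.5, 1.0]
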
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
@@ -129,6 +129,7 @@ class StatsKeysConstant(object):
    # multimodal
    image_text_similarity = 'image_text_similarity'
    image_text_matching_score = 'image_text_matching_score'
    phrase_grounding_recall = 'phrase_grounding_recall'


class StatsKeys(object, metaclass=StatsKeysMeta):
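
The new key holds one recall value per text chunk of a sample, so after the filter's compute_stats runs, a sample's stats entry would look roughly like the line below (the values are purely illustrative):

sample[Fields.stats][StatsKeys.phrase_grounding_recall]  # e.g. [0.75, 1.0]
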
39 changes: 28 additions & 11 deletions data_juicer/utils/mm_utils.py
@@ -26,30 +26,21 @@ def get_special_tokens():

def remove_special_tokens(text):
    for value in get_special_tokens().values():
        text = text.replace(value, '')
        text = text.replace(value, '').strip()
    return text


# Image
def load_images(paths):
    return [load_image(path) for path in paths]


def load_audios(paths):
    return [load_audio(path) for path in paths]


def load_image(path):
    img_feature = Image()
    img = img_feature.decode_example(img_feature.encode_example(path))
    return img


def load_audio(path, sampling_rate=None):
    aud_feature = Audio(sampling_rate)
    aud = aud_feature.decode_example(aud_feature.encode_example(path))
    return (aud['array'], aud['sampling_rate'])


def pil_to_opencv(pil_image):
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')
@@ -64,6 +55,32 @@ def get_image_size(path, ):
    return os.path.getsize(path)


def iou(box1, box2):
    x1_min, y1_min, x1_max, y1_max = box1
    x2_min, y2_min, x2_max, y2_max = box2
    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)
    ix_min = max(x1_min, x2_min)
    ix_max = min(x1_max, x2_max)
    iy_min = max(y1_min, y2_min)
    iy_max = min(y1_max, y2_max)
    # clamp each side separately so fully disjoint boxes get zero intersection
    intersection = max(0, ix_max - ix_min) * max(0, iy_max - iy_min)
    union = area1 + area2 - intersection
    return 1.0 * intersection / union
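
# Worked example (editor's note, not part of this diff): for box1 = (0, 0, 2, 2)
# and box2 = (1, 1, 3, 3), the overlap is the unit square (1, 1, 2, 2), so the
# function returns 1 / (4 + 4 - 1), i.e. about 0.143.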


# Audio
def load_audios(paths):
    return [load_audio(path) for path in paths]


def load_audio(path, sampling_rate=None):
    aud_feature = Audio(sampling_rate)
    aud = aud_feature.decode_example(aud_feature.encode_example(path))
    return aud['array'], aud['sampling_rate']


# Others
def size_to_bytes(size):
    alphabets_list = [char for char in size if char.isalpha()]
    numbers_list = [char for char in size if char.isdigit()]