Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/clip similarity filter #69

Merged
merged 17 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ process:
rep_len: 10 # repetition length for char-level n-gram
min_ratio: 0.0 # the min ratio of filter range
max_ratio: 0.5 # the max ratio of filter range
- clip_similarity_filter: # filter samples according to the similarity between text and images.
hf_clip: openai/clip-vit-base-patch32 # name of used Hugging Face clip
min_score: 0.1 # the min similarity of filter range
max_score: 1.0 # the max similarity of filter range
reduce_mode: avg # reduce mode when one text corresponds to multiple images in a chunk, must be one of ['avg','max', 'min'].
any_or_all: any # keep this sample when any/all images meet the filter condition
- flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value
lang: en # consider flagged words in what language
tokenization: false # whether to use model to tokenize documents
Expand Down
14 changes: 7 additions & 7 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from . import (alphanumeric_filter, average_line_length_filter,
character_repetition_filter, flagged_words_filter,
image_aspect_ratio_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
text_length_filter, token_num_filter, word_num_filter,
word_repetition_filter)
character_repetition_filter, clip_similarity_filter,
flagged_words_filter, image_aspect_ratio_filter,
language_id_score_filter, maximum_line_length_filter,
perplexity_filter, special_characters_filter,
specified_field_filter, specified_numeric_field_filter,
stopwords_filter, suffix_filter, text_length_filter,
token_num_filter, word_num_filter, word_repetition_filter)
158 changes: 158 additions & 0 deletions data_juicer/ops/filter/clip_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import numpy as np
import torch
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import SpecialTokens, load_image
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES

# avoid hanging when calling clip in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module('clip_similarity_filter')
@LOADED_IMAGES.register_module('clip_similarity_filter')
class ClipSimilarityFilter(Filter):
"""Filter to keep samples those similarity between image and text
within a specific range."""

def __init__(self,
hf_clip='openai/clip-vit-base-patch32',
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
any_or_all: str = 'any',
reduce_mode: str = 'avg',
*args,
**kwargs):
"""
Initialization method.

:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param reduce_mode: reduce mode when one text corresponds to
multiple images in a chunk.
'avg': Take the average of multiple values
'max': Take the max of multiple values
'min': Take the min of multiple values
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
if reduce_mode not in ['avg', 'max', 'min']:
raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
f'Can only be one of ["avg", "max", "min"].')
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='hf_clip', model_key=hf_clip)
self.reduce_mode = reduce_mode

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.clip_image_text_similarity in sample[Fields.stats]:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][
StatsKeys.clip_image_text_similarity] = np.array(
[], dtype=np.float64)
return sample

# load images
loaded_image_keys = sample[self.image_key]
images = {}
for loaded_image_key in loaded_image_keys:
if context and loaded_image_key in sample[Fields.context]:
# load from context
images[loaded_image_key] = sample[
Fields.context][loaded_image_key]
else:
if loaded_image_key not in images:
# avoid load the same images
image = load_image(loaded_image_key)
images[loaded_image_key] = image
if context:
# store the image data into context
sample[Fields.context][loaded_image_key] = image

text = sample[self.text_key]
special_token_dict = {
key: value
for key, value in SpecialTokens.__dict__.items()
if not key.startswith('__')
}
offset = 0

def remove_special_token(text):
for value in special_token_dict.values():
text = text.replace(value, '')
return text

similarity = []
model, processor = get_model(self.model_key)

for chunk in text.split(SpecialTokens.eoc):
count = chunk.count(SpecialTokens.image)

# no image or no text
if count == 0 or len(chunk) == 0:
continue
else:
text_chunk = remove_special_token(chunk)
image_chunk = [
images[image_key]
for image_key in loaded_image_keys[offset:offset + count]
]

inputs = processor(text=text_chunk,
images=image_chunk,
return_tensors='pt',
truncation=True,
max_length=model.config.text_config.
max_position_embeddings,
padding=True)

outputs = model(**inputs)
chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0

if self.reduce_mode == 'avg':
chunk_similarity = chunk_logits.mean()
elif self.reduce_mode == 'max':
chunk_similarity = chunk_logits.max()
else:
chunk_similarity = chunk_logits.min()

similarity.append(float(chunk_similarity))
offset += count
sample[Fields.stats][StatsKeys.clip_image_text_similarity] = similarity

return sample

def process(self, sample):
similarity = sample[Fields.stats][StatsKeys.clip_image_text_similarity]
if len(similarity) <= 0:
return True

keep_bools = np.array([
self.min_score <= sim_value <= self.max_score
for sim_value in similarity
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
3 changes: 3 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class StatsKeys(object):
# image
aspect_ratios = 'aspect_ratios'

# multimodal
clip_image_text_similarity = 'clip_image_text_similarity'


class HashKeys(object):
hash = DEFAULT_PREFIX + 'hash'
Expand Down
22 changes: 22 additions & 0 deletions data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ def prepare_huggingface_tokenizer(tokenizer_name):
return tokenizer


def prepare_huggingface_clip(clip_name):
"""
Prepare and load a clip and processor from HuggingFace.

:param clip_name: input clip name
:return: a pair of clip instance and processor instance.
"""
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained(clip_name)
processor = CLIPProcessor.from_pretrained(clip_name)
logger.info('Loading clip and processor from HuggingFace...')

return (model, processor)


def prepare_diversity_model(model_name, lang):
"""
Prepare diversity model for specific language.
Expand Down Expand Up @@ -222,6 +238,7 @@ def prepare_model(lang='en', model_type='sentencepiece', model_key=None):
'kenlm': ('%s.arpa.bin', prepare_kenlm_model),
'nltk': ('punkt.%s.pickle', prepare_nltk_model),
'huggingface': ('%s', prepare_huggingface_tokenizer),
'hf_clip': ('%s', prepare_huggingface_clip),
'spacy': ('%s_core_web_md-3.5.0', prepare_diversity_model),
}
assert model_type in type_to_name.keys(
Expand All @@ -236,6 +253,11 @@ def prepare_model(lang='en', model_type='sentencepiece', model_key=None):
MODEL_ZOO[model_key] = model_func(model_name)
elif model_type == 'huggingface':
MODEL_ZOO[model_key] = model_func(model_key)
elif model_type == 'hf_clip':
new_model_key = model_type + model_key
if new_model_key not in MODEL_ZOO.keys():
MODEL_ZOO[new_model_key] = model_func(model_key)
model_key = new_model_key
else:
MODEL_ZOO[model_key] = model_func(model_name, lang)
return model_key
Expand Down
3 changes: 2 additions & 1 deletion demos/overview_scan/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
|-----------------------------------|:------:|-------------------------------------------------|
| Formatter | 7 | Discovers, loads, and canonicalizes source data |
| Mapper | 21 | Edits and transforms samples |
| Filter | 17 | Filters out low-quality samples |
| Filter | 18 | Filters out low-quality samples |
| Deduplicator | 3 | Detects and removes duplicate samples |
| Selector | 2 | Selects top samples based on ranking |
'''
Expand Down Expand Up @@ -140,6 +140,7 @@
| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
Expand Down
5 changes: 4 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
| [ Filter ]( #filter ) | 17 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 18 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 3 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |

Expand All @@ -23,6 +23,8 @@ All the specific operators are listed below, each featured with several capabili
- LaTeX: specific to LaTeX source files
- Code: specific to programming codes
- Financial: closely related to financial sector
- Image: specific to image or multimodal
- Multimodal: specific to multimodal
* Language Tags
- en: English
- zh: Chinese
Expand Down Expand Up @@ -75,6 +77,7 @@ All the specific operators are listed below, each featured with several capabili
| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range |
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
| clip_similarity_filter | Multimodal | - | Keeps samples with similarity between text and images within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
| image_aspect_ratio_filter | Image | - | Keeps samples contains images with aspect ratios within specific range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
Expand Down
6 changes: 5 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 17 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 18 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 3 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |

Expand All @@ -21,6 +21,9 @@ Data-Juicer 中的算子分为以下 5 种类型。
- LaTeX: 专用于 LaTeX 源文件
- Code: 专用于编程代码
- Financial: 与金融领域相关
- Image: 专用于图像或多模态
- Multimodal: 专用于多模态

* Language 标签
- en: 英文
- zh: 中文
Expand Down Expand Up @@ -71,6 +74,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| alphanumeric_filter | General | en, zh | 保留字母数字比例在指定范围内的样本 |
| average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 |
| character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 |
| clip_similarity_filter | Multimodal | - | 保留文本图像相似度在指定范围内的样本 |
| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
| language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 |
Expand Down
1 change: 1 addition & 0 deletions environments/science_requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ nlpcda
nltk
transformers
opencc==1.1.6
torch
zhijianma marked this conversation as resolved.
Show resolved Hide resolved
Binary file added tests/ops/data/cat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading