Skip to content

Commit

Permalink
Merge branch 'main' into service/fastapi
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege committed Sep 10, 2024
2 parents b2b4673 + 7954241 commit 0f4fda9
Show file tree
Hide file tree
Showing 13 changed files with 242 additions and 22 deletions.
5 changes: 5 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ process:
score_threshold: 0.5 # the nsfw score threshold for samples, range from 0 to 1. Samples with nsfw score less than this threshold will be kept.
any_or_all: any # keep this sample when any/all images meet the filter condition
mem_required: '1GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched
- image_pair_similarity_filter: # filter samples according to the similarity score between the image pair.
hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface
min_score: 0.1 # the min similarity score of filter range
max_score: 1.0 # the max similarity score of filter range
any_or_all: "any" # keep this sample when any/all images meet the filter condition
- image_shape_filter: # filter samples according to the widths and heights of images in them
min_width: 200 # the min width of width filter range
max_width: 5000 # the max width of width filter range
Expand Down
26 changes: 18 additions & 8 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ def init_configs(args=None):
help='Suffixes of files that will be find and loaded. If not set, we '
'will find all suffix files, and select a suitable formatter '
'with the most files as default.')
parser.add_argument(
'--turbo',
type=bool,
default=False,
help='Enable Turbo mode to maximize processing speed. Stability '
'features like fault tolerance will be disabled.')
parser.add_argument(
'--use_cache',
type=bool,
Expand Down Expand Up @@ -470,6 +476,8 @@ def init_setup_from_cfg(cfg):
'image_key': cfg.image_key,
'audio_key': cfg.audio_key,
'video_key': cfg.video_key,
'num_proc': cfg.np,
'turbo': cfg.turbo,
}
else:
if 'text_key' not in args or args['text_key'] is None:
Expand All @@ -480,6 +488,10 @@ def init_setup_from_cfg(cfg):
args['audio_key'] = cfg.audio_key
if 'video_key' not in args or args['video_key'] is None:
args['video_key'] = cfg.video_key
if 'num_proc' not in args or args['num_proc'] is None:
args['num_proc'] = cfg.np
if 'turbo' not in args or args['turbo'] is None:
args['turbo'] = cfg.turbo
op[op_name] = args

return cfg
Expand Down Expand Up @@ -574,14 +586,12 @@ def update_op_process(cfg, parser):

# update op params of cfg.process
internal_op_para = temp_cfg.get(op_in_process_name)
if internal_op_para is not None:
num_proc = internal_op_para.get('num_proc')
if 'num_proc' in internal_op_para:
internal_op_para['num_proc'] = num_proc or cfg.np
internal_op_para = namespace_to_dict(internal_op_para)
else:
internal_op_para = None
cfg.process[i] = {op_in_process_name: internal_op_para}

cfg.process[i] = {
op_in_process_name:
None if internal_op_para is None else
namespace_to_dict(internal_op_para)
}

# check the op params via type hint
temp_parser = copy.deepcopy(parser)
Expand Down
9 changes: 6 additions & 3 deletions data_juicer/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,12 @@ def __init__(self, cfg=None):

# setup formatter
logger.info('Setting up data formatter...')
self.formatter = load_formatter(self.cfg.dataset_path,
self.cfg.text_keys, self.cfg.suffixes,
self.cfg.add_suffix)
self.formatter = load_formatter(
dataset_path=self.cfg.dataset_path,
generated_dataset_config=self.cfg.generated_dataset_config,
text_keys=self.cfg.text_keys,
suffixes=self.cfg.suffixes,
add_suffix=self.cfg.add_suffix)

# prepare exporter and check export path suffix
# NOTICE: no need to export dataset texts for analyzer
Expand Down
12 changes: 8 additions & 4 deletions data_juicer/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,15 @@ def map(self, *args, **kargs):
called_func, '__wrapped__'):
called_func = called_func.__wrapped__

# Batched is always required for fault tolerance
if inspect.ismethod(called_func):
kargs['batched'] = True
kargs['batch_size'] = kargs.pop(
'batch_size', 1) if called_func.__self__.is_batched_op() else 1
# batched is required for fault-tolerant or batched OP
if not called_func.__self__.turbo or hasattr(
called_func.__self__,
'is_batched_op') and called_func.__self__.is_batched_op():
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1)
else:
kargs['batched'] = False

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
Expand Down
10 changes: 6 additions & 4 deletions data_juicer/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def __init__(self, cfg=None):

# setup formatter
logger.info('Setting up data formatter...')
self.formatter = load_formatter(self.cfg.dataset_path,
self.cfg.generated_dataset_config,
self.cfg.text_keys, self.cfg.suffixes,
self.cfg.add_suffix)
self.formatter = load_formatter(
dataset_path=self.cfg.dataset_path,
generated_dataset_config=self.cfg.generated_dataset_config,
text_keys=self.cfg.text_keys,
suffixes=self.cfg.suffixes,
add_suffix=self.cfg.add_suffix)

# whether to use checkpoint mechanism. If it's true, Executor will
# check if there are existing checkpoints first and try to load the
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def __init__(self, *args, **kwargs):
if isinstance(self.mem_required, str):
self.mem_required = size_to_bytes(self.mem_required) / 1024**3

self.turbo = kwargs.get('turbo', False)

# nested wrappers
from data_juicer.core.data import wrap_func_with_nested_access
for name in ['process', 'compute_stats', 'compute_hash']:
Expand Down
5 changes: 4 additions & 1 deletion data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
average_line_length_filter, character_repetition_filter,
flagged_words_filter, image_aesthetics_filter,
image_aspect_ratio_filter, image_face_ratio_filter,
image_nsfw_filter, image_shape_filter, image_size_filter,
image_nsfw_filter, image_pair_similarity_filter,
image_shape_filter, image_size_filter,
image_text_matching_filter, image_text_similarity_filter,
image_watermark_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
Expand All @@ -30,6 +31,7 @@
from .image_aspect_ratio_filter import ImageAspectRatioFilter
from .image_face_ratio_filter import ImageFaceRatioFilter
from .image_nsfw_filter import ImageNSFWFilter
from .image_pair_similarity_filter import ImagePairSimilarityFilter
from .image_shape_filter import ImageShapeFilter
from .image_size_filter import ImageSizeFilter
from .image_text_matching_filter import ImageTextMatchingFilter
Expand Down Expand Up @@ -104,6 +106,7 @@
'FlaggedWordFilter',
'WordRepetitionFilter',
'VideoMotionScoreFilter',
'ImagePairSimilarityFilter'
]

# yapf: enable
114 changes: 114 additions & 0 deletions data_juicer/ops/filter/image_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.ops.base_op import OPERATORS, Filter
from data_juicer.ops.op_fusion import LOADED_IMAGES
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'image_pair_similarity_filter'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):

import torch
import transformers # noqa: F401

# avoid hanging when calling clip in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImagePairSimilarityFilter(Filter):
"""Filter to keep image pairs with similarities between images
within a specific range."""

_accelerator = 'cuda'

def __init__(self,
hf_clip='openai/clip-vit-base-patch32',
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.
:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)

def compute_stats(self, sample, rank=None, context=False):

# check if it's computed already
if StatsKeys.image_pair_similarity in sample[Fields.stats]:
return sample

# there is no image in this sample
if (self.image_key not in sample
or not len(sample[self.image_key]) == 2
or sample[self.image_key][0] == sample[self.image_key][1]):
raise ValueError('Each sample must include two images.')

# load images
loaded_image_keys = sample[self.image_key]
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

similarity = []
model, processor = get_model(self.model_key, rank, self.use_cuda())

image_list = []
for temp_key in images.keys():
image_list.append(images[temp_key])
image_tensors = processor.image_processor(
image_list, return_tensors='pt')['pixel_values']
image1_batch_feature = model.get_image_features(
image_tensors[0].unsqueeze(0).to(model.device))
image2_batch_feature = model.get_image_features(
image_tensors[1].unsqueeze(0).to(model.device))

similarity = torch.cosine_similarity(image1_batch_feature,
image2_batch_feature,
dim=1)
sample[Fields.stats][StatsKeys.image_pair_similarity] = similarity

return sample

def process(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.image_pair_similarity]
if len(similarity) <= 0:
return True

keep_bools = np.array([
self.min_score <= sim_value <= self.max_score
for sim_value in similarity
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class StatsKeysConstant(object):
image_aesthetics_scores = 'image_aesthetics_scores'
image_nsfw_score = 'image_nsfw_score'
image_watermark_prob = 'image_watermark_prob'
image_pair_similarity = 'image_pair_similarity'

# audios
audio_duration = 'audio_duration'
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 42 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |

Expand Down Expand Up @@ -113,6 +113,7 @@ All the specific operators are listed below, each featured with several capabili
| image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range |
| image_face_ratio_filter | Image | - | Keeps samples containing images with face area ratios within the specified range |
| image_nsfw_filter | Image | - | Keeps samples containing images with NSFW scores below the threshold |
| image_pair_similarity_filter | Image | - | Keeps image pairs with image feature cosine similarity within the specified range based on a CLIP model |
| image_shape_filter | Image | - | Keeps samples containing images with widths and heights within the specified range |
| image_size_filter | Image | - | Keeps samples containing images whose size in bytes are within the specified range |
| image_text_matching_filter | Multimodal | - | Keeps samples with image-text classification matching score within the specified range based on a BLIP model |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 41 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 42 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |

Expand Down Expand Up @@ -111,6 +111,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
| image_face_ratio_filter | Image | - | 保留样本中包含的图片的最大脸部区域在指定范围内的样本 |
| image_nsfw_filter | Image | - | 保留包含NSFW分数在指定阈值之下的图像的样本 |
| image_pair_similarity_filter | Image | - | 保留图像特征余弦相似度(基于CLIP模型)在指定范围内的样本 |
| image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 |
| image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 |
| image_text_matching_filter | Multimodal | - | 保留图像-文本的分类匹配分(基于BLIP模型)在指定范围内的样本 |
Expand Down
7 changes: 7 additions & 0 deletions tests/config/test_config_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_yaml_cfg_file(self):
'num_proc': 4,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
Expand All @@ -65,6 +66,7 @@ def test_yaml_cfg_file(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
}, 'nested dict load fail, un-expected internal value')

Expand Down Expand Up @@ -131,6 +133,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -147,6 +150,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -163,6 +167,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -179,6 +184,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})
self.assertDictEqual(
Expand All @@ -195,6 +201,7 @@ def test_mixture_cfg(self):
'stats_export_path': None,
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
}
})

Expand Down
Loading

0 comments on commit 0f4fda9

Please sign in to comment.