Skip to content

Commit

Permalink
1. truncation text inputs.
Browse files Browse the repository at this point in the history
2. fix hanging when multi-process.
  • Loading branch information
zhijianma committed Nov 14, 2023
1 parent af7101d commit a7e2b2e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
14 changes: 10 additions & 4 deletions data_juicer/ops/filter/clip_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import torch
from jsonargparse.typing import PositiveFloat

from data_juicer.utils.constant import Fields, StatsKeys
Expand All @@ -8,6 +9,9 @@
from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES

# avoid hanging when calling clip in multiprocessing
torch.get_num_threads()


@OPERATORS.register_module('clip_similarity_filter')
@LOADED_IMAGES.register_module('clip_similarity_filter')
Expand Down Expand Up @@ -43,7 +47,6 @@ def __init__(self,
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.image_key = 'images'
self.min_ratio = min_ratio
self.max_ratio = max_ratio
if reduce_mode not in ['avg', 'max', 'min']:
Expand Down Expand Up @@ -117,20 +120,23 @@ def remove_special_token(text):
inputs = processor(text=text_chunk,
images=image_chunk,
return_tensors='pt',
truncation=True,
max_length=model.config.text_config.
max_position_embeddings,
padding=True)

outputs = model(**inputs)
chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0

if self.reduce_mode == 'avg':
chunk_similarity = chunk_logits.mean()
elif self.reduce_mode == 'max':
chunk_similarity = chunk_logits.max()
else:
chunk_similarity = chunk_logits.min()

similarity.append(float(chunk_similarity))
offset += count

similarity.append(float(chunk_similarity))
offset += count
sample[Fields.stats][StatsKeys.clip_image_text_similarity] = similarity

return sample
Expand Down
34 changes: 28 additions & 6 deletions tests/ops/filter/test_clip_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ class ClipSimilarityFilterTest(unittest.TestCase):
img3_path = os.path.join(data_path, 'img3.jpg')
hf_clip = 'openai/clip-vit-base-patch32'

def _run_filter(self, dataset: Dataset, target_list, op):
def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1):

if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats)
dataset = dataset.filter(op.process)

dataset = dataset.map(op.compute_stats, num_proc=num_proc)
dataset = dataset.filter(op.process, num_proc=num_proc)
dataset = dataset.select_columns(column_names=['text', 'images'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_reduce_avg(self):
max_ratio=0.9)
self._run_filter(dataset, tgt_list, op)

def xxtest_reduce_max(self):
def test_reduce_max(self):

ds_list = [{
'text': f'{SpecialTokens.image}a photo of a cat '
Expand Down Expand Up @@ -182,6 +182,28 @@ def test_reduce_min(self):
op.min_ratio = 0.2
self._run_filter(dataset, [], op)

def test_multi_process(self):

ds_list = [{
'text':
f'{SpecialTokens.image}a photo of a cat {SpecialTokens.eoc} '
f'{SpecialTokens.image}a photo of a dog {SpecialTokens.eoc}',
'images': [self.cat_path, self.cat_path]
}] * 10
tgt_list = [{
'text':
f'{SpecialTokens.image}a photo of a cat {SpecialTokens.eoc} '
f'{SpecialTokens.image}a photo of a dog {SpecialTokens.eoc}',
'images': [self.cat_path, self.cat_path]
}] * 10
dataset = Dataset.from_list(ds_list)
op = ClipSimilarityFilter(hf_clip=self.hf_clip,
reduce_mode='avg',
any_or_all='any',
min_ratio=0.2,
max_ratio=0.9)
self._run_filter(dataset, tgt_list, op, num_proc=4)


if __name__ == '__main__':
unittest.main()

0 comments on commit a7e2b2e

Please sign in to comment.