Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Qirui-jiao committed Sep 2, 2024
1 parent 6f79643 commit 376936b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
1 change: 1 addition & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ process:
hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface
min_score: 0.1 # the min similarity score of filter range
max_score: 1.0 # the max similarity score of filter range
text_key_second: None # used to store the other sentence in the text pair
any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition
- token_num_filter: # filter text with total token number out of specific range
hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer
Expand Down
24 changes: 17 additions & 7 deletions data_juicer/ops/filter/text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import numpy as np
from jsonargparse.typing import ClosedUnitInterval

Expand All @@ -6,6 +8,9 @@
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.model_utils import get_model, prepare_model

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

OP_NAME = 'text_pair_similarity_filter'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):
Expand All @@ -29,6 +34,7 @@ def __init__(self,
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
text_key_second=None,
any_or_all: str = 'any',
*args,
**kwargs):
Expand All @@ -39,6 +45,8 @@ def __init__(self,
the similarity between image and text.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param text_key_second: used to store the other sentence
in the text pair.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
Expand All @@ -56,7 +64,7 @@ def __init__(self,
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)
self.new_sample_key = ['target_text']
self.text_key_second = text_key_second

def compute_stats(self, sample, rank=None, context=False):

Expand All @@ -65,21 +73,23 @@ def compute_stats(self, sample, rank=None, context=False):
return sample

# there is no target text
for temp_new_key in self.new_sample_key:
if temp_new_key not in sample or len(sample[temp_new_key]) == 0:
raise ValueError(
f'Key \'{temp_new_key}\' is not found in sample. ')
if self.text_key_second is None:
logger.error('This OP (text_pair_similarity_filter) requires \
processing multiple fields, and you need to specify \
valid `text_key_second`')

# there is no text in this sample
if (self.text_key not in sample or len(sample[self.text_key]) == 0):
if (self.text_key not in sample or len(sample[self.text_key]) == 0
or self.text_key_second not in sample
or len(sample[self.text_key_second]) == 0):
sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array(
[], dtype=np.float64)
return sample

model, processor = get_model(self.model_key, rank, self.use_cuda())

text1 = sample[self.text_key]
text2 = sample['target_text']
text2 = sample[self.text_key_second]

text_tensors = processor([text1, text2],
padding=True,
Expand Down
23 changes: 14 additions & 9 deletions tests/ops/filter/test_text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

class TextPairSimilarityFilterTest(DataJuicerTestCaseBase):

hf_clip = 'openai/clip-vit-base-patch32'
hf_clip = "openai/clip-vit-base-patch32"

text_key = "text"
text_key_second = "target_text"


@classmethod
Expand All @@ -31,29 +34,31 @@ def _run_filter(self, dataset: Dataset, op, num_proc=1):
num_proc=num_proc,
with_rank=True)
dataset = dataset.filter(op.process, num_proc=num_proc)
dataset = dataset.select_columns(column_names=['text', 'target_text'])
dataset = dataset.select_columns(column_names=[self.text_key,
self.text_key_second])
res_list = dataset.to_list()
print(res_list)

def test_no_eoc_special_token(self):

ds_list = [{
'target_text': 'a lovely cat',
'text': 'a lovely cat',
self.text_key_second: 'a lovely cat',
self.text_key: 'a lovely cat',
}, {
'target_text': 'a lovely cat',
'text': 'a cute cat',
self.text_key_second: 'a lovely cat',
self.text_key: 'a cute cat',
}, {
'target_text': 'a lovely cat',
'text': 'a black dog',
self.text_key_second: 'a lovely cat',
self.text_key: 'a black dog',
}]


dataset = Dataset.from_list(ds_list)
op = TextPairSimilarityFilter(hf_clip=self.hf_clip,
any_or_all='any',
min_score=0.1,
max_score=0.85)
max_score=0.99,
text_key_second=self.text_key_second)
self._run_filter(dataset, op)


Expand Down

0 comments on commit 376936b

Please sign in to comment.