From 376936b3db6f269d2e408d8bb22161a3a2ec5f3b Mon Sep 17 00:00:00 2001 From: Qirui-jiao Date: Mon, 2 Sep 2024 21:20:55 +0800 Subject: [PATCH] update --- configs/config_all.yaml | 1 + .../ops/filter/text_pair_similarity_filter.py | 24 +++++++++++++------ .../test_text_pair_similarity_filter.py | 23 +++++++++++------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 8a7045daa..89ce40cb7 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -404,6 +404,7 @@ process: hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface min_score: 0.1 # the min similarity score of filter range max_score: 1.0 # the max similarity score of filter range + text_key_second: None # the key name of the field that stores the other sentence in the text pair any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition - token_num_filter: # filter text with total token number out of specific range hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer diff --git a/data_juicer/ops/filter/text_pair_similarity_filter.py b/data_juicer/ops/filter/text_pair_similarity_filter.py index e8151f28d..635c4c640 100644 --- a/data_juicer/ops/filter/text_pair_similarity_filter.py +++ b/data_juicer/ops/filter/text_pair_similarity_filter.py @@ -1,3 +1,5 @@ +import logging + import numpy as np from jsonargparse.typing import ClosedUnitInterval @@ -6,6 +8,9 @@ from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.model_utils import get_model, prepare_model +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + OP_NAME = 'text_pair_similarity_filter' with AvailabilityChecking(['torch', 'transformers'], OP_NAME): @@ -29,6 +34,7 @@ def __init__(self, trust_remote_code=False, min_score: ClosedUnitInterval = 0.1, max_score: ClosedUnitInterval = 1.0, + text_key_second=None, any_or_all: str = 'any', *args, **kwargs): @@ -39,6 
+45,8 @@ def __init__(self, the similarity between image and text. :param min_score: The min similarity to keep samples. :param max_score: The max similarity to keep samples. + :param text_key_second: the key name of the field that stores + the other sentence in the text pair. :param any_or_all: keep this sample with 'any' or 'all' strategy of all images. 'any': keep this sample if any images meet the condition. 'all': keep this sample only if all images meet the @@ -56,7 +64,7 @@ def __init__(self, self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_clip, trust_remote_code=trust_remote_code) - self.new_sample_key = ['target_text'] + self.text_key_second = text_key_second def compute_stats(self, sample, rank=None, context=False): @@ -65,13 +73,15 @@ def compute_stats(self, sample, rank=None, context=False): return sample # there is no target text - for temp_new_key in self.new_sample_key: - if temp_new_key not in sample or len(sample[temp_new_key]) == 0: - raise ValueError( - f'Key \'{temp_new_key}\' is not found in sample. 
') + if self.text_key_second is None: + logger.error('This OP (text_pair_similarity_filter) requires \ + processing multiple fields, and you need to specify \ + valid `text_key_second`') # there is no text in this sample - if (self.text_key not in sample or len(sample[self.text_key]) == 0): + if (self.text_key not in sample or len(sample[self.text_key]) == 0 + or self.text_key_second not in sample + or len(sample[self.text_key_second]) == 0): sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array( [], dtype=np.float64) return sample @@ -79,7 +89,7 @@ def compute_stats(self, sample, rank=None, context=False): model, processor = get_model(self.model_key, rank, self.use_cuda()) text1 = sample[self.text_key] - text2 = sample['target_text'] + text2 = sample[self.text_key_second] text_tensors = processor([text1, text2], padding=True, diff --git a/tests/ops/filter/test_text_pair_similarity_filter.py b/tests/ops/filter/test_text_pair_similarity_filter.py index b65849bc9..083849443 100644 --- a/tests/ops/filter/test_text_pair_similarity_filter.py +++ b/tests/ops/filter/test_text_pair_similarity_filter.py @@ -11,7 +11,10 @@ class TextPairSimilarityFilterTest(DataJuicerTestCaseBase): - hf_clip = 'openai/clip-vit-base-patch32' + hf_clip = "openai/clip-vit-base-patch32" + + text_key = "text" + text_key_second = "target_text" @classmethod @@ -31,21 +34,22 @@ def _run_filter(self, dataset: Dataset, op, num_proc=1): num_proc=num_proc, with_rank=True) dataset = dataset.filter(op.process, num_proc=num_proc) - dataset = dataset.select_columns(column_names=['text', 'target_text']) + dataset = dataset.select_columns(column_names=[self.text_key, + self.text_key_second]) res_list = dataset.to_list() print(res_list) def test_no_eoc_special_token(self): ds_list = [{ - 'target_text': 'a lovely cat', - 'text': 'a lovely cat', + self.text_key_second: 'a lovely cat', + self.text_key: 'a lovely cat', }, { - 'target_text': 'a lovely cat', - 'text': 'a cute cat', + self.text_key_second: 
'a lovely cat', + self.text_key: 'a cute cat', }, { - 'target_text': 'a lovely cat', - 'text': 'a black dog', + self.text_key_second: 'a lovely cat', + self.text_key: 'a black dog', }] @@ -53,7 +57,8 @@ def test_no_eoc_special_token(self): op = TextPairSimilarityFilter(hf_clip=self.hf_clip, any_or_all='any', min_score=0.1, - max_score=0.85) + max_score=0.99, + text_key_second=self.text_key_second) self._run_filter(dataset, op)