diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py
index 94d4af70..8e556088 100644
--- a/data_juicer/ops/filter/text_length_filter.py
+++ b/data_juicer/ops/filter/text_length_filter.py
@@ -33,17 +33,27 @@ def __init__(self,
         self.min_len = min_len
         self.max_len = max_len
 
-    def compute_stats(self, sample):
-        # check if it's computed already
-        if StatsKeys.text_len in sample[Fields.stats]:
-            return sample
-
-        sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
-        return sample
-
-    def process(self, sample):
-        if self.min_len <= sample[Fields.stats][
-                StatsKeys.text_len] <= self.max_len:
-            return True
+    def compute_stats_batched(self, samples):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
+        for i, stat in enumerate(samples_stats):
+            # check if it's computed already
+            if StatsKeys.text_len in stat:
+                continue
+            else:
+                samples_stats[i][StatsKeys.text_len] = len(samples_list[i])
+
+        return samples
+
+    def process_batched(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            return map(
+                lambda stat: self.min_len <= stat[StatsKeys.text_len] <= self.
+                max_len, samples[Fields.stats])
         else:
-            return False
+            # single sample for ray filter
+            if self.min_len <= samples[Fields.stats][
+                    StatsKeys.text_len] <= self.max_len:
+                return True
+            else:
+                return False
diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py
index 059129f4..ec163e42 100644
--- a/data_juicer/ops/filter/word_repetition_filter.py
+++ b/data_juicer/ops/filter/word_repetition_filter.py
@@ -58,57 +58,68 @@ def __init__(self,
             self.model_key = prepare_model(model_type='sentencepiece',
                                            lang=lang)
 
-    def compute_stats(self, sample, context=False):
-        # check if it's computed already
-        if StatsKeys.word_rep_ratio in sample[Fields.stats]:
-            return sample
-
-        # try to get words from context
-        words_key = f'{InterVars.words}-{self.model_key}'
-        if context and words_key in sample[Fields.context]:
-            words = sample[Fields.context][words_key]
+    def compute_stats_batched(self, samples, context=False):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
+
+        for idx, stat in enumerate(samples_stats):
+            words_key = f'{InterVars.words}-{self.model_key}-{idx}'
+            # check if it's computed already
+            if StatsKeys.word_rep_ratio in stat:
+                continue
+            # try to get words from context
+            if context and words_key in samples[Fields.context]:
+                words = samples[Fields.context][words_key]
+            else:
+                tokenizer = get_model(self.model_key)
+                words = get_words_from_document(
+                    samples_list[idx],
+                    token_func=tokenizer.encode_as_pieces
+                    if tokenizer else None)
+                if context:
+                    samples[Fields.context][words_key] = words
+
+            # try to get refined words from context
+            refined_words_key = f'{InterVars.refined_words}-' \
+                                f'True-SPECIAL_CHARS-False-[2]-{idx}'
+            if context and refined_words_key in samples[Fields.context]:
+                words = samples[Fields.context][refined_words_key]
+            else:
+                words = words_refinement(words,
+                                         lower_case=True,
+                                         strip_chars=SPECIAL_CHARACTERS)
+                if context:
+                    samples[Fields.context][refined_words_key] = words
+            word_ngrams = [
+                ' '.join(words[i:i + self.n])
+                for i in range(len(words) - self.n + 1)
+            ]
+            freq_word_ngrams = {}
+            for word_ngram in word_ngrams:
+                freq_word_ngrams[word_ngram] = (
+                    freq_word_ngrams.get(word_ngram, 0) + 1)
+
+            if len(freq_word_ngrams) == 0:
+                samples_stats[idx][StatsKeys.word_rep_ratio] = 0.0
+                continue
+
+            freq_word_ngrams = list(freq_word_ngrams.values())
+            rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1]
+            samples_stats[idx][StatsKeys.word_rep_ratio] = (
+                sum(rep_more_than_one) /
+                sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0
+
+        return samples
+
+    def process_batched(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            return map(
+                lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio]
+                <= self.max_ratio, samples[Fields.stats])
         else:
-            tokenizer = get_model(self.model_key)
-            words = get_words_from_document(
-                sample[self.text_key],
-                token_func=tokenizer.encode_as_pieces if tokenizer else None)
-            if context:
-                sample[Fields.context][words_key] = words
-
-        # try to get refined words from context
-        refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \
-                            f'False-[2]-'
-        if context and refined_words_key in sample[Fields.context]:
-            words = sample[Fields.context][refined_words_key]
-        else:
-            words = words_refinement(words,
-                                     lower_case=True,
-                                     strip_chars=SPECIAL_CHARACTERS)
-            if context:
-                sample[Fields.context][refined_words_key] = words
-        word_ngrams = [
-            ' '.join(words[i:i + self.n])
-            for i in range(len(words) - self.n + 1)
-        ]
-        freq_word_ngrams = {}
-        for word_ngram in word_ngrams:
-            freq_word_ngrams[word_ngram] = (
-                freq_word_ngrams.get(word_ngram, 0) + 1)
-
-        if len(freq_word_ngrams) == 0:
-            sample[Fields.stats][StatsKeys.word_rep_ratio] = 0.0
-            return sample
-
-        freq_word_ngrams = list(freq_word_ngrams.values())
-        rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1]
-        sample[Fields.stats][StatsKeys.word_rep_ratio] = (
-            sum(rep_more_than_one) /
-            sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0
-        return sample
-
-    def process(self, sample):
-        if self.min_ratio <= sample[Fields.stats][StatsKeys.word_rep_ratio] \
-                <= self.max_ratio:
-            return True
-        else:
-            return False
+            # single sample for ray filter
+            if self.min_ratio <= samples[Fields.stats][
+                    StatsKeys.word_rep_ratio] <= self.max_ratio:
+                return True
+            else:
+                return False
diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py
index ccd204f7..28aa4ad4 100644
--- a/data_juicer/ops/filter/words_num_filter.py
+++ b/data_juicer/ops/filter/words_num_filter.py
@@ -51,28 +51,39 @@ def __init__(self,
             self.model_key = prepare_model(model_type='sentencepiece',
                                            lang=lang)
 
-    def compute_stats(self, sample, context=False):
-        # check if it's computed already
-        if StatsKeys.num_words in sample[Fields.stats]:
-            return sample
+    def compute_stats_batched(self, samples, context=False):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
 
-        words_key = f'{InterVars.words}-{self.model_key}'
-        if context and words_key in sample[Fields.context]:
-            words = sample[Fields.context][words_key]
-        else:
-            tokenizer = get_model(self.model_key)
-            words = get_words_from_document(
-                sample[self.text_key],
-                token_func=tokenizer.encode_as_pieces if tokenizer else None)
-            if context:
-                sample[Fields.context][words_key] = words
-        words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
-        sample[Fields.stats][StatsKeys.num_words] = len(words)
-        return sample
+        for idx, stat in enumerate(samples_stats):
+            words_key = f'{InterVars.words}-{self.model_key}-{idx}'
+            # check if it's computed already
+            if StatsKeys.num_words in stat:
+                continue
+            if context and words_key in samples[Fields.context]:
+                words = samples[Fields.context][words_key]
+            else:
+                tokenizer = get_model(self.model_key)
+                words = get_words_from_document(
+                    samples_list[idx],
+                    token_func=tokenizer.encode_as_pieces
+                    if tokenizer else None)
+                if context:
+                    samples[Fields.context][words_key] = words
+            words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
+            samples_stats[idx][StatsKeys.num_words] = len(words)
+
+        return samples
 
-    def process(self, sample):
-        if self.min_num <= sample[Fields.stats][
-                StatsKeys.num_words] <= self.max_num:
-            return True
+    def process_batched(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            return map(
+                lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self.
+                max_num, samples[Fields.stats])
         else:
-            return False
+            # single sample for ray filter
+            if self.min_num <= samples[Fields.stats][
+                    StatsKeys.num_words] <= self.max_num:
+                return True
+            else:
+                return False
diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py
index af62bc3e..57f624b0 100644
--- a/data_juicer/ops/mapper/whitespace_normalization_mapper.py
+++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py
@@ -27,12 +27,15 @@ def __init__(self, *args, **kwargs):
         """
         super().__init__(*args, **kwargs)
 
-    def process(self, sample):
-        text = sample[self.text_key].strip()
-
-        # replace all kinds of whitespaces with ' '
-        sample[self.text_key] = ''.join([
-            char if char not in VARIOUS_WHITESPACES else ' ' for char in text
-        ])
-
-        return sample
+    def process_batched(self, samples):
+        for idx, text in enumerate(samples[self.text_key]):
+            # remove whitespaces before and after the main content
+            text = text.strip()
+
+            # replace all kinds of whitespaces with ' '
+            samples[self.text_key][idx] = ''.join([
+                char if char not in VARIOUS_WHITESPACES else ' '
+                for char in text
+            ])
+
+        return samples
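
For reference, the new batched entry points can be exercised directly on a column-oriented batch (one list per field, aligned by index). The snippet below is a minimal sketch, not part of the patch: it assumes the default `'text'` key and pre-initialized empty stats dicts, which the `Filter` wrapper would normally add before `compute_stats_batched` runs.

```python
# Minimal usage sketch for the batched filter interface (assumptions noted above).
from data_juicer.ops.filter.text_length_filter import TextLengthFilter
from data_juicer.utils.constant import Fields

op = TextLengthFilter(min_len=5, max_len=50)

# Column-oriented batch: one entry per sample in each list.
samples = {
    'text': ['hi', 'a sentence that is long enough to keep'],
    Fields.stats: [{}, {}],
}

# Fill in per-sample text_len stats, then evaluate the keep/drop predicate.
samples = op.compute_stats_batched(samples)
keep_flags = list(op.process_batched(samples))  # expected: [False, True]
print(keep_flags)
```

Note that `process_batched` returns a lazy `map` object when it receives a list of stats, so callers that need concrete booleans (as in the sketch) should materialize it; the non-list branch keeps the old single-sample behavior for the ray filter path.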