From 8bf87d0daccdb190c617e51824a255c818a2609e Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Mon, 14 Oct 2024 17:25:57 +0800 Subject: [PATCH 01/13] * extract batched to an outer func * modified for text_length_filter --- data_juicer/ops/base_op.py | 18 ++++++++-- data_juicer/ops/filter/text_length_filter.py | 36 +++++++------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 901e8523e..6d329b2f5 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -289,6 +289,20 @@ def __init__(self, *args, **kwargs): else: self.compute_stats = catch_map_single_exception(self.compute_stats) + def compute_stats_batched(self, samples, **kwargs): + keys = samples.keys() + samples_stats = samples[Fields.stats] + for i, stat in enumerate(samples_stats): + this_sample = {key: samples[key][i] for key in keys} + res_sample = self.compute_stats(this_sample, **kwargs) + samples[Fields.stats][i] = res_sample[Fields.stats] + + return samples + + def process_batched(self, samples): + return map(lambda stat: self.process({Fields.stats: stat}), + samples[Fields.stats]) + def compute_stats(self, sample, context=False): """ Compute stats for the sample which is used as a metric to decide @@ -322,14 +336,14 @@ def run(self, dataset, *, exporter=None, tracer=None): num_proc=self.runtime_np(), batch_size=self.batch_size, desc='Adding new column for stats') - dataset = dataset.map(self.compute_stats, + dataset = dataset.map(self.compute_stats_batched, num_proc=self.runtime_np(), with_rank=self.use_cuda(), batch_size=self.batch_size, desc=self._name + '_compute_stats') if exporter and self.stats_export_path is not None: exporter.export_compute_stats(dataset, self.stats_export_path) - new_dataset = dataset.filter(self.process, + new_dataset = dataset.filter(self.process_batched, num_proc=self.runtime_np(), batch_size=self.batch_size, desc=self._name + '_process') diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py index 51e0bd68d..94d4af704 100644 --- a/data_juicer/ops/filter/text_length_filter.py +++ b/data_juicer/ops/filter/text_length_filter.py @@ -33,27 +33,17 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, samples): - samples_list = samples[self.text_key] - samples_stats = samples[Fields.stats] - for i, stat in enumerate(samples_stats): - # check if it's computed already - if StatsKeys.text_len in stat: - continue - else: - samples_stats[i][StatsKeys.text_len] = len(samples_list[i]) - - return samples - - def process(self, samples): - if isinstance(samples[Fields.stats], list): - return map( - lambda stat: self.min_len <= stat[StatsKeys.text_len] <= self. 
- max_len, samples[Fields.stats]) + def compute_stats(self, sample): + # check if it's computed already + if StatsKeys.text_len in sample[Fields.stats]: + return sample + + sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key]) + return sample + + def process(self, sample): + if self.min_len <= sample[Fields.stats][ + StatsKeys.text_len] <= self.max_len: + return True else: - # single sample for ray filter - if self.min_len <= samples[Fields.stats][ - StatsKeys.text_len] <= self.max_len: - return True - else: - return False + return False From 4cd9e62ef4643cf286f3c01583adcd12d6db6208 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Mon, 14 Oct 2024 17:46:03 +0800 Subject: [PATCH 02/13] * use a branch to decide which funcs are used --- data_juicer/ops/base_op.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 6d329b2f5..ee9fe7a0c 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -284,15 +284,19 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): - self.compute_stats = catch_map_batches_exception( - self.compute_stats) + self.compute_stats_branch = catch_map_batches_exception( + self.compute_stats_batched) + self.process_branch = catch_map_batches_exception( + self.process_batched) else: - self.compute_stats = catch_map_single_exception(self.compute_stats) + self.compute_stats_branch = catch_map_single_exception( + self.compute_stats) + self.process_branch = catch_map_single_exception(self.process) def compute_stats_batched(self, samples, **kwargs): keys = samples.keys() samples_stats = samples[Fields.stats] - for i, stat in enumerate(samples_stats): + for i, _ in enumerate(samples_stats): this_sample = {key: samples[key][i] for key in keys} res_sample = self.compute_stats(this_sample, **kwargs) samples[Fields.stats][i] = res_sample[Fields.stats] @@ -336,14 +340,14 @@ def run(self, dataset, *, exporter=None, tracer=None): num_proc=self.runtime_np(), batch_size=self.batch_size, desc='Adding new column for stats') - dataset = dataset.map(self.compute_stats_batched, + dataset = dataset.map(self.compute_stats_branch, num_proc=self.runtime_np(), with_rank=self.use_cuda(), batch_size=self.batch_size, desc=self._name + '_compute_stats') if exporter and self.stats_export_path is not None: exporter.export_compute_stats(dataset, self.stats_export_path) - new_dataset = dataset.filter(self.process_batched, + new_dataset = dataset.filter(self.process_branch, num_proc=self.runtime_np(), batch_size=self.batch_size, desc=self._name + '_process') From ad5dbba7a88e8d0e7cb4966ff3373afd5c3155d7 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Mon, 14 Oct 2024 17:48:37 +0800 Subject: [PATCH 03/13] * update context for each sample as well --- data_juicer/ops/base_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index ee9fe7a0c..384b243e9 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -300,6 +300,8 @@ def compute_stats_batched(self, samples, **kwargs): this_sample = {key: samples[key][i] for key in keys} res_sample = self.compute_stats(this_sample, **kwargs) samples[Fields.stats][i] = res_sample[Fields.stats] + if 'context' in kwargs and kwargs['context']: + samples[Fields.context][i] = res_sample[Fields.context] return samples From 4d6ffe1150ee2f54fe8ecac7e551956923c9f0ec Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Mon, 14 Oct 2024 21:11:29 
+0800 Subject: [PATCH 04/13] * modify for mapper and whitespace_normalization_mapper --- data_juicer/ops/base_op.py | 20 ++++++++++++++---- .../mapper/whitespace_normalization_mapper.py | 21 ++++++++----------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 384b243e9..9b22f63ae 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -236,9 +236,21 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): - self.process = catch_map_batches_exception(self.process) + self.process_branch = catch_map_batches_exception( + self.process_batched) else: - self.process = catch_map_single_exception(self.process) + self.process_branch = catch_map_single_exception(self.process) + + def process_batched(self, samples): + keys = samples.keys() + first_key = list(keys)[0] + for i in range(len(samples[first_key])): + this_sample = {key: samples[key][i] for key in keys} + res_sample = self.process(this_sample) + for key in keys: + samples[key][i] = res_sample[key] + + return samples def process(self, sample): """ @@ -252,7 +264,7 @@ def process(self, sample): def run(self, dataset, *, exporter=None, tracer=None): dataset = super(Mapper, self).run(dataset) new_dataset = dataset.map( - self.process, + self.process_branch, num_proc=self.runtime_np(), with_rank=self.use_cuda(), batch_size=self.batch_size, @@ -296,7 +308,7 @@ def __init__(self, *args, **kwargs): def compute_stats_batched(self, samples, **kwargs): keys = samples.keys() samples_stats = samples[Fields.stats] - for i, _ in enumerate(samples_stats): + for i in range(len(samples_stats)): this_sample = {key: samples[key][i] for key in keys} res_sample = self.compute_stats(this_sample, **kwargs) samples[Fields.stats][i] = res_sample[Fields.stats] diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py index 3102cedab..af62bc3e7 100644 --- a/data_juicer/ops/mapper/whitespace_normalization_mapper.py +++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py @@ -27,15 +27,12 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - def process(self, samples): - for idx, text in enumerate(samples[self.text_key]): - # remove whitespaces before and after the main content - text = text.strip() - - # replace all kinds of whitespaces with ' ' - samples[self.text_key][idx] = ''.join([ - char if char not in VARIOUS_WHITESPACES else ' ' - for char in text - ]) - - return samples + def process(self, sample): + text = sample[self.text_key].strip() + + # replace all kinds of whitespaces with ' ' + sample[self.text_key] = ''.join([ + char if char not in VARIOUS_WHITESPACES else ' ' for char in text + ]) + + return sample From 1929ddd7081cf96688cb0a4af4a575a686584fe4 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Mon, 14 Oct 2024 21:41:14 +0800 Subject: [PATCH 05/13] * modify for two filters with context --- .../ops/filter/word_repetition_filter.py | 115 ++++++++---------- data_juicer/ops/filter/words_num_filter.py | 55 ++++----- 2 files changed, 74 insertions(+), 96 deletions(-) diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index 41a081694..059129f41 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -58,68 +58,57 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def 
compute_stats(self, samples, context=False): - samples_list = samples[self.text_key] - samples_stats = samples[Fields.stats] - words_key = f'{InterVars.words}-{self.model_key}' + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.word_rep_ratio in sample[Fields.stats]: + return sample - for idx, stat in enumerate(samples_stats): - # check if it's computed already - if StatsKeys.word_rep_ratio in stat: - continue - # try to get words from context - if context and words_key in samples[Fields.context][idx]: - words = samples[Fields.context][idx][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - samples_list[idx], - token_func=tokenizer.encode_as_pieces - if tokenizer else None) - if context: - samples[Fields.context][idx][words_key] = words - - # try to get refined words from context - refined_words_key = f'{InterVars.refined_words}-' \ - f'True-SPECIAL_CHARS-False-[2]-' - if context and refined_words_key in samples[Fields.context][idx]: - words = samples[Fields.context][idx][refined_words_key] - else: - words = words_refinement(words, - lower_case=True, - strip_chars=SPECIAL_CHARACTERS) - if context: - samples[Fields.context][idx][refined_words_key] = words - word_ngrams = [ - ' '.join(words[i:i + self.n]) - for i in range(len(words) - self.n + 1) - ] - freq_word_ngrams = {} - for word_ngram in word_ngrams: - freq_word_ngrams[word_ngram] = ( - freq_word_ngrams.get(word_ngram, 0) + 1) - - if len(freq_word_ngrams) == 0: - samples_stats[idx][StatsKeys.word_rep_ratio] = 0.0 - continue - - freq_word_ngrams = list(freq_word_ngrams.values()) - rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] - samples_stats[idx][StatsKeys.word_rep_ratio] = ( - sum(rep_more_than_one) / - sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 - - return samples - - def process(self, samples): - if isinstance(samples[Fields.stats], list): - return map( - lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio] - <= self.max_ratio, samples[Fields.stats]) + # try to get words from context + words_key = f'{InterVars.words}-{self.model_key}' + if context and words_key in sample[Fields.context]: + words = sample[Fields.context][words_key] + else: + tokenizer = get_model(self.model_key) + words = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None) + if context: + sample[Fields.context][words_key] = words + + # try to get refined words from context + refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \ + f'False-[2]-' + if context and refined_words_key in sample[Fields.context]: + words = sample[Fields.context][refined_words_key] + else: + words = words_refinement(words, + lower_case=True, + strip_chars=SPECIAL_CHARACTERS) + if context: + sample[Fields.context][refined_words_key] = words + word_ngrams = [ + ' '.join(words[i:i + self.n]) + for i in range(len(words) - self.n + 1) + ] + freq_word_ngrams = {} + for word_ngram in word_ngrams: + freq_word_ngrams[word_ngram] = ( + freq_word_ngrams.get(word_ngram, 0) + 1) + + if len(freq_word_ngrams) == 0: + sample[Fields.stats][StatsKeys.word_rep_ratio] = 0.0 + return sample + + freq_word_ngrams = list(freq_word_ngrams.values()) + rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] + sample[Fields.stats][StatsKeys.word_rep_ratio] = ( + sum(rep_more_than_one) / + sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 + return sample + + def process(self, 
sample): + if self.min_ratio <= sample[Fields.stats][StatsKeys.word_rep_ratio] \ + <= self.max_ratio: + return True else: - # single sample for ray filter - if self.min_ratio <= samples[Fields.stats][ - StatsKeys.word_rep_ratio] <= self.max_ratio: - return True - else: - return False + return False diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py index 413a2171d..ccd204f7e 100644 --- a/data_juicer/ops/filter/words_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -51,39 +51,28 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, samples, context=False): - samples_list = samples[self.text_key] - samples_stats = samples[Fields.stats] - words_key = f'{InterVars.words}-{self.model_key}' - - for idx, stat in enumerate(samples_stats): - # check if it's computed already - if StatsKeys.num_words in stat: - continue - if context and words_key in samples[Fields.context][idx]: - words = samples[Fields.context][idx][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - samples_list[idx], - token_func=tokenizer.encode_as_pieces - if tokenizer else None) - if context: - samples[Fields.context][idx][words_key] = words - words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) - samples_stats[idx][StatsKeys.num_words] = len(words) + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.num_words in sample[Fields.stats]: + return sample - return samples + words_key = f'{InterVars.words}-{self.model_key}' + if context and words_key in sample[Fields.context]: + words = sample[Fields.context][words_key] + else: + tokenizer = get_model(self.model_key) + words = get_words_from_document( + sample[self.text_key], + token_func=tokenizer.encode_as_pieces if tokenizer else None) + if context: + sample[Fields.context][words_key] = words + words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) + sample[Fields.stats][StatsKeys.num_words] = len(words) + return sample - def process(self, samples): - if isinstance(samples[Fields.stats], list): - return map( - lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self. 
- max_num, samples[Fields.stats]) + def process(self, sample): + if self.min_num <= sample[Fields.stats][ + StatsKeys.num_words] <= self.max_num: + return True else: - # single sample for ray filter - if self.min_num <= samples[Fields.stats][ - StatsKeys.num_words] <= self.max_num: - return True - else: - return False + return False From b98f6d9a650cdfd134db356892be83dc1bdf7fd6 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Tue, 15 Oct 2024 16:19:38 +0800 Subject: [PATCH 06/13] * allow optional args for batched funcs --- data_juicer/ops/base_op.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 9b22f63ae..c0da2abd0 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -241,12 +241,12 @@ def __init__(self, *args, **kwargs): else: self.process_branch = catch_map_single_exception(self.process) - def process_batched(self, samples): + def process_batched(self, samples, *args, **kwargs): keys = samples.keys() first_key = list(keys)[0] for i in range(len(samples[first_key])): this_sample = {key: samples[key][i] for key in keys} - res_sample = self.process(this_sample) + res_sample = self.process(this_sample, *args, **kwargs) for key in keys: samples[key][i] = res_sample[key] @@ -305,12 +305,12 @@ def __init__(self, *args, **kwargs): self.compute_stats) self.process_branch = catch_map_single_exception(self.process) - def compute_stats_batched(self, samples, **kwargs): + def compute_stats_batched(self, samples, *args, **kwargs): keys = samples.keys() samples_stats = samples[Fields.stats] for i in range(len(samples_stats)): this_sample = {key: samples[key][i] for key in keys} - res_sample = self.compute_stats(this_sample, **kwargs) + res_sample = self.compute_stats(this_sample, *args, **kwargs) samples[Fields.stats][i] = res_sample[Fields.stats] if 'context' in kwargs and kwargs['context']: samples[Fields.context][i] = res_sample[Fields.context] From c2188d515f909b145f7aba26a91987ed48f300f2 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Tue, 15 Oct 2024 17:49:57 +0800 Subject: [PATCH 07/13] * restore to batched version and rename to xxx_batched --- data_juicer/ops/filter/text_length_filter.py | 36 ++++-- .../ops/filter/word_repetition_filter.py | 117 ++++++++++-------- data_juicer/ops/filter/words_num_filter.py | 55 ++++---- .../mapper/whitespace_normalization_mapper.py | 21 ++-- 4 files changed, 132 insertions(+), 97 deletions(-) diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py index 94d4af704..8e5560884 100644 --- a/data_juicer/ops/filter/text_length_filter.py +++ b/data_juicer/ops/filter/text_length_filter.py @@ -33,17 +33,27 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, sample): - # check if it's computed already - if StatsKeys.text_len in sample[Fields.stats]: - return sample - - sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key]) - return sample - - def process(self, sample): - if self.min_len <= sample[Fields.stats][ - StatsKeys.text_len] <= self.max_len: - return True + def compute_stats_batched(self, samples): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] + for i, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.text_len in stat: + continue + else: + samples_stats[i][StatsKeys.text_len] = len(samples_list[i]) + + return samples + + def process_batched(self, samples): + if 
isinstance(samples[Fields.stats], list): + return map( + lambda stat: self.min_len <= stat[StatsKeys.text_len] <= self. + max_len, samples[Fields.stats]) else: - return False + # single sample for ray filter + if self.min_len <= samples[Fields.stats][ + StatsKeys.text_len] <= self.max_len: + return True + else: + return False diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index 059129f41..ec163e429 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -58,57 +58,68 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.word_rep_ratio in sample[Fields.stats]: - return sample - - # try to get words from context - words_key = f'{InterVars.words}-{self.model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] + def compute_stats_batched(self, samples, context=False): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] + + for idx, stat in enumerate(samples_stats): + words_key = f'{InterVars.words}-{self.model_key}-{idx}' + # check if it's computed already + if StatsKeys.word_rep_ratio in stat: + continue + # try to get words from context + if context and words_key in samples[Fields.context]: + words = samples[Fields.context][words_key] + else: + tokenizer = get_model(self.model_key) + words = get_words_from_document( + samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + if context: + samples[Fields.context][words_key] = words + + # try to get refined words from context + refined_words_key = f'{InterVars.refined_words}-' \ + f'True-SPECIAL_CHARS-False-[2]-{idx}' + if context and refined_words_key in samples[Fields.context]: + words = samples[Fields.context][refined_words_key] + else: + words = words_refinement(words, + lower_case=True, + strip_chars=SPECIAL_CHARACTERS) + if context: + samples[Fields.context][refined_words_key] = words + word_ngrams = [ + ' '.join(words[i:i + self.n]) + for i in range(len(words) - self.n + 1) + ] + freq_word_ngrams = {} + for word_ngram in word_ngrams: + freq_word_ngrams[word_ngram] = ( + freq_word_ngrams.get(word_ngram, 0) + 1) + + if len(freq_word_ngrams) == 0: + samples_stats[idx][StatsKeys.word_rep_ratio] = 0.0 + continue + + freq_word_ngrams = list(freq_word_ngrams.values()) + rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] + samples_stats[idx][StatsKeys.word_rep_ratio] = ( + sum(rep_more_than_one) / + sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 + + return samples + + def process_batched(self, samples): + if isinstance(samples[Fields.stats], list): + return map( + lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio] + <= self.max_ratio, samples[Fields.stats]) else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - - # try to get refined words from context - refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \ - f'False-[2]-' - if context and refined_words_key in sample[Fields.context]: - words = sample[Fields.context][refined_words_key] - else: - words = words_refinement(words, - lower_case=True, - strip_chars=SPECIAL_CHARACTERS) - if context: - 
sample[Fields.context][refined_words_key] = words - word_ngrams = [ - ' '.join(words[i:i + self.n]) - for i in range(len(words) - self.n + 1) - ] - freq_word_ngrams = {} - for word_ngram in word_ngrams: - freq_word_ngrams[word_ngram] = ( - freq_word_ngrams.get(word_ngram, 0) + 1) - - if len(freq_word_ngrams) == 0: - sample[Fields.stats][StatsKeys.word_rep_ratio] = 0.0 - return sample - - freq_word_ngrams = list(freq_word_ngrams.values()) - rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] - sample[Fields.stats][StatsKeys.word_rep_ratio] = ( - sum(rep_more_than_one) / - sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 - return sample - - def process(self, sample): - if self.min_ratio <= sample[Fields.stats][StatsKeys.word_rep_ratio] \ - <= self.max_ratio: - return True - else: - return False + # single sample for ray filter + if self.min_ratio <= samples[Fields.stats][ + StatsKeys.word_rep_ratio] <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py index ccd204f7e..28aa4ad49 100644 --- a/data_juicer/ops/filter/words_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -51,28 +51,39 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.num_words in sample[Fields.stats]: - return sample + def compute_stats_batched(self, samples, context=False): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] - words_key = f'{InterVars.words}-{self.model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) - sample[Fields.stats][StatsKeys.num_words] = len(words) - return sample + for idx, stat in enumerate(samples_stats): + words_key = f'{InterVars.words}-{self.model_key}-{idx}' + # check if it's computed already + if StatsKeys.num_words in stat: + continue + if context and words_key in samples[Fields.context]: + words = samples[Fields.context][words_key] + else: + tokenizer = get_model(self.model_key) + words = get_words_from_document( + samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + if context: + samples[Fields.context][words_key] = words + words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) + samples_stats[idx][StatsKeys.num_words] = len(words) + + return samples - def process(self, sample): - if self.min_num <= sample[Fields.stats][ - StatsKeys.num_words] <= self.max_num: - return True + def process_batched(self, samples): + if isinstance(samples[Fields.stats], list): + return map( + lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self. 
+ max_num, samples[Fields.stats]) else: - return False + # single sample for ray filter + if self.min_num <= samples[Fields.stats][ + StatsKeys.num_words] <= self.max_num: + return True + else: + return False diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py index af62bc3e7..57f624b06 100644 --- a/data_juicer/ops/mapper/whitespace_normalization_mapper.py +++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py @@ -27,12 +27,15 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - def process(self, sample): - text = sample[self.text_key].strip() - - # replace all kinds of whitespaces with ' ' - sample[self.text_key] = ''.join([ - char if char not in VARIOUS_WHITESPACES else ' ' for char in text - ]) - - return sample + def process_batched(self, samples): + for idx, text in enumerate(samples[self.text_key]): + # remove whitespaces before and after the main content + text = text.strip() + + # replace all kinds of whitespaces with ' ' + samples[self.text_key][idx] = ''.join([ + char if char not in VARIOUS_WHITESPACES else ' ' + for char in text + ]) + + return samples From 281c68d6d1ae8c71c55aab968238d333572c80f4 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Tue, 15 Oct 2024 17:54:04 +0800 Subject: [PATCH 08/13] * restore to batched version and rename to xxx_batched --- data_juicer/ops/filter/word_repetition_filter.py | 16 ++++++++-------- data_juicer/ops/filter/words_num_filter.py | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index ec163e429..46843e7d9 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -61,15 +61,15 @@ def __init__(self, def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] + words_key = f'{InterVars.words}-{self.model_key}' for idx, stat in enumerate(samples_stats): - words_key = f'{InterVars.words}-{self.model_key}-{idx}' # check if it's computed already if StatsKeys.word_rep_ratio in stat: continue # try to get words from context - if context and words_key in samples[Fields.context]: - words = samples[Fields.context][words_key] + if context and words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] else: tokenizer = get_model(self.model_key) words = get_words_from_document( @@ -77,19 +77,19 @@ def compute_stats_batched(self, samples, context=False): token_func=tokenizer.encode_as_pieces if tokenizer else None) if context: - samples[Fields.context][words_key] = words + samples[Fields.context][idx][words_key] = words # try to get refined words from context refined_words_key = f'{InterVars.refined_words}-' \ - f'True-SPECIAL_CHARS-False-[2]-{idx}' - if context and refined_words_key in samples[Fields.context]: - words = samples[Fields.context][refined_words_key] + f'True-SPECIAL_CHARS-False-[2]-' + if context and refined_words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][refined_words_key] else: words = words_refinement(words, lower_case=True, strip_chars=SPECIAL_CHARACTERS) if context: - samples[Fields.context][refined_words_key] = words + samples[Fields.context][idx][refined_words_key] = words word_ngrams = [ ' '.join(words[i:i + self.n]) for i in range(len(words) - self.n + 1) diff --git 
a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py index 28aa4ad49..21e0fe9c9 100644 --- a/data_juicer/ops/filter/words_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -54,14 +54,14 @@ def __init__(self, def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] + words_key = f'{InterVars.words}-{self.model_key}' for idx, stat in enumerate(samples_stats): - words_key = f'{InterVars.words}-{self.model_key}-{idx}' # check if it's computed already if StatsKeys.num_words in stat: continue - if context and words_key in samples[Fields.context]: - words = samples[Fields.context][words_key] + if context and words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] else: tokenizer = get_model(self.model_key) words = get_words_from_document( @@ -69,7 +69,7 @@ def compute_stats_batched(self, samples, context=False): token_func=tokenizer.encode_as_pieces if tokenizer else None) if context: - samples[Fields.context][words_key] = words + samples[Fields.context][idx][words_key] = words words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) samples_stats[idx][StatsKeys.num_words] = len(words) From 5bf31737631fc54f52809ad4bd406db3895ba2e4 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Thu, 17 Oct 2024 10:30:44 +0800 Subject: [PATCH 09/13] * restore to batched version and rename to xxx_batched --- data_juicer/ops/base_op.py | 44 +++++++++---------- data_juicer/ops/filter/alphanumeric_filter.py | 4 +- .../ops/filter/audio_duration_filter.py | 4 +- .../ops/filter/audio_nmf_snr_filter.py | 4 +- data_juicer/ops/filter/audio_size_filter.py | 4 +- .../ops/filter/average_line_length_filter.py | 4 +- .../ops/filter/character_repetition_filter.py | 4 +- .../ops/filter/flagged_words_filter.py | 4 +- .../ops/filter/image_aesthetics_filter.py | 4 +- .../ops/filter/image_aspect_ratio_filter.py | 4 +- .../ops/filter/image_face_ratio_filter.py | 4 +- data_juicer/ops/filter/image_nsfw_filter.py | 4 +- .../filter/image_pair_similarity_filter.py | 4 +- data_juicer/ops/filter/image_shape_filter.py | 4 +- data_juicer/ops/filter/image_size_filter.py | 4 +- .../ops/filter/image_text_matching_filter.py | 4 +- .../filter/image_text_similarity_filter.py | 4 +- .../ops/filter/image_watermark_filter.py | 4 +- .../ops/filter/language_id_score_filter.py | 4 +- .../ops/filter/maximum_line_length_filter.py | 4 +- data_juicer/ops/filter/perplexity_filter.py | 4 +- .../filter/phrase_grounding_recall_filter.py | 4 +- .../ops/filter/special_characters_filter.py | 4 +- .../ops/filter/specified_field_filter.py | 4 +- .../filter/specified_numeric_field_filter.py | 4 +- data_juicer/ops/filter/stopwords_filter.py | 4 +- data_juicer/ops/filter/suffix_filter.py | 4 +- data_juicer/ops/filter/text_action_filter.py | 4 +- .../filter/text_entity_dependency_filter.py | 4 +- data_juicer/ops/filter/token_num_filter.py | 4 +- .../ops/filter/video_aesthetics_filter.py | 4 +- .../ops/filter/video_aspect_ratio_filter.py | 4 +- .../ops/filter/video_duration_filter.py | 4 +- .../video_frames_text_similarity_filter.py | 4 +- .../ops/filter/video_motion_score_filter.py | 4 +- data_juicer/ops/filter/video_nsfw_filter.py | 4 +- .../ops/filter/video_ocr_area_ratio_filter.py | 4 +- .../ops/filter/video_resolution_filter.py | 4 +- .../video_tagging_from_frames_filter.py | 4 +- .../ops/filter/video_watermark_filter.py | 4 +- .../ops/mapper/audio_ffmpeg_wrapped_mapper.py | 2 +- 
.../ops/mapper/chinese_convert_mapper.py | 2 +- .../ops/mapper/clean_copyright_mapper.py | 2 +- data_juicer/ops/mapper/clean_email_mapper.py | 2 +- data_juicer/ops/mapper/clean_html_mapper.py | 2 +- data_juicer/ops/mapper/clean_ip_mapper.py | 2 +- data_juicer/ops/mapper/clean_links_mapper.py | 2 +- data_juicer/ops/mapper/expand_macro_mapper.py | 2 +- data_juicer/ops/mapper/extract_qa_mapper.py | 2 +- data_juicer/ops/mapper/fix_unicode_mapper.py | 2 +- .../ops/mapper/generate_instruction_mapper.py | 2 +- data_juicer/ops/mapper/image_blur_mapper.py | 2 +- .../image_captioning_from_gpt4v_mapper.py | 2 +- .../ops/mapper/image_captioning_mapper.py | 2 +- .../ops/mapper/image_diffusion_mapper.py | 2 +- .../ops/mapper/image_face_blur_mapper.py | 2 +- .../ops/mapper/image_tagging_mapper.py | 2 +- data_juicer/ops/mapper/nlpaug_en_mapper.py | 2 +- data_juicer/ops/mapper/nlpcda_zh_mapper.py | 2 +- .../ops/mapper/optimize_instruction_mapper.py | 2 +- .../punctuation_normalization_mapper.py | 2 +- .../ops/mapper/remove_bibliography_mapper.py | 2 +- .../ops/mapper/remove_comments_mapper.py | 2 +- .../ops/mapper/remove_header_mapper.py | 2 +- .../ops/mapper/remove_long_words_mapper.py | 2 +- .../remove_non_chinese_character_mapper.py | 2 +- .../mapper/remove_repeat_sentences_mapper.py | 2 +- .../mapper/remove_specific_chars_mapper.py | 2 +- .../ops/mapper/remove_table_text_mapper.py | 2 +- ..._words_with_incorrect_substrings_mapper.py | 2 +- .../ops/mapper/replace_content_mapper.py | 2 +- .../ops/mapper/sentence_split_mapper.py | 2 +- .../video_captioning_from_audio_mapper.py | 2 +- .../video_captioning_from_frames_mapper.py | 2 +- ...video_captioning_from_summarizer_mapper.py | 2 +- .../video_captioning_from_video_mapper.py | 2 +- .../ops/mapper/video_face_blur_mapper.py | 2 +- .../ops/mapper/video_ffmpeg_wrapped_mapper.py | 2 +- .../mapper/video_remove_watermark_mapper.py | 2 +- .../video_resize_aspect_ratio_mapper.py | 2 +- .../mapper/video_resize_resolution_mapper.py | 2 +- .../mapper/video_split_by_duration_mapper.py | 2 +- .../mapper/video_split_by_key_frame_mapper.py | 2 +- .../ops/mapper/video_split_by_scene_mapper.py | 2 +- .../mapper/video_tagging_from_audio_mapper.py | 2 +- .../video_tagging_from_frames_mapper.py | 2 +- 86 files changed, 146 insertions(+), 146 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index c0da2abd0..ef4307a8c 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -236,23 +236,23 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): - self.process_branch = catch_map_batches_exception( - self.process_batched) + self.process = catch_map_batches_exception(self.process_batched) else: - self.process_branch = catch_map_single_exception(self.process) + self.process = catch_map_single_exception(self.process_single) def process_batched(self, samples, *args, **kwargs): keys = samples.keys() - first_key = list(keys)[0] - for i in range(len(samples[first_key])): + first_key = next(iter(keys)) + num_samples = len(samples[first_key]) + for i in range(num_samples): this_sample = {key: samples[key][i] for key in keys} - res_sample = self.process(this_sample, *args, **kwargs) + res_sample = self.process_single(this_sample, *args, **kwargs) for key in keys: samples[key][i] = res_sample[key] return samples - def process(self, sample): + def process_single(self, sample): """ For sample level, sample --> sample @@ -264,7 +264,7 @@ def process(self, sample): def run(self, dataset, *, exporter=None, 
tracer=None): dataset = super(Mapper, self).run(dataset) new_dataset = dataset.map( - self.process_branch, + self.process, num_proc=self.runtime_np(), with_rank=self.use_cuda(), batch_size=self.batch_size, @@ -296,21 +296,21 @@ def __init__(self, *args, **kwargs): # runtime wrappers if self.is_batched_op(): - self.compute_stats_branch = catch_map_batches_exception( + self.compute_stats = catch_map_batches_exception( self.compute_stats_batched) - self.process_branch = catch_map_batches_exception( - self.process_batched) + self.process = catch_map_batches_exception(self.process_batched) else: - self.compute_stats_branch = catch_map_single_exception( - self.compute_stats) - self.process_branch = catch_map_single_exception(self.process) + self.compute_stats = catch_map_single_exception( + self.compute_stats_single) + self.process = catch_map_single_exception(self.process_single) def compute_stats_batched(self, samples, *args, **kwargs): keys = samples.keys() - samples_stats = samples[Fields.stats] - for i in range(len(samples_stats)): + num_samples = len(samples[Fields.stats]) + for i in range(num_samples): this_sample = {key: samples[key][i] for key in keys} - res_sample = self.compute_stats(this_sample, *args, **kwargs) + res_sample = self.compute_stats_single(this_sample, *args, + **kwargs) samples[Fields.stats][i] = res_sample[Fields.stats] if 'context' in kwargs and kwargs['context']: samples[Fields.context][i] = res_sample[Fields.context] @@ -318,10 +318,10 @@ def compute_stats_batched(self, samples, *args, **kwargs): return samples def process_batched(self, samples): - return map(lambda stat: self.process({Fields.stats: stat}), + return map(lambda stat: self.process_single({Fields.stats: stat}), samples[Fields.stats]) - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): """ Compute stats for the sample which is used as a metric to decide whether to filter this sample. @@ -333,7 +333,7 @@ def compute_stats(self, sample, context=False): """ raise NotImplementedError - def process(self, sample): + def process_single(self, sample): """ For sample level, sample --> Boolean. 
@@ -354,14 +354,14 @@ def run(self, dataset, *, exporter=None, tracer=None): num_proc=self.runtime_np(), batch_size=self.batch_size, desc='Adding new column for stats') - dataset = dataset.map(self.compute_stats_branch, + dataset = dataset.map(self.compute_stats, num_proc=self.runtime_np(), with_rank=self.use_cuda(), batch_size=self.batch_size, desc=self._name + '_compute_stats') if exporter and self.stats_export_path is not None: exporter.export_compute_stats(dataset, self.stats_export_path) - new_dataset = dataset.filter(self.process_branch, + new_dataset = dataset.filter(self.process, num_proc=self.runtime_np(), batch_size=self.batch_size, desc=self._name + '_process') diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py index e6ea7cc7e..411749e25 100644 --- a/data_juicer/ops/filter/alphanumeric_filter.py +++ b/data_juicer/ops/filter/alphanumeric_filter.py @@ -51,7 +51,7 @@ def __init__(self, pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped', return_model=False) - def compute_stats(self, samples): + def compute_stats_batched(self, samples): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] @@ -79,7 +79,7 @@ def compute_stats(self, samples): return samples - def process(self, samples): + def process_batched(self, samples): ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \ else StatsKeys.alnum_ratio if isinstance(samples[Fields.stats], list): diff --git a/data_juicer/ops/filter/audio_duration_filter.py b/data_juicer/ops/filter/audio_duration_filter.py index cf70206f6..c73bdd6c4 100644 --- a/data_juicer/ops/filter/audio_duration_filter.py +++ b/data_juicer/ops/filter/audio_duration_filter.py @@ -46,7 +46,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.audio_duration in sample[Fields.stats]: return sample @@ -74,7 +74,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): audio_durations = sample[Fields.stats][StatsKeys.audio_duration] keep_bools = np.array([ self.min_duration <= duration <= self.max_duration diff --git a/data_juicer/ops/filter/audio_nmf_snr_filter.py b/data_juicer/ops/filter/audio_nmf_snr_filter.py index 1ae16c5f8..72f056383 100644 --- a/data_juicer/ops/filter/audio_nmf_snr_filter.py +++ b/data_juicer/ops/filter/audio_nmf_snr_filter.py @@ -96,7 +96,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.audio_nmf_snr in sample[Fields.stats]: return sample @@ -124,7 +124,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): audio_snrs = sample[Fields.stats][StatsKeys.audio_nmf_snr] keep_bools = np.array( [self.min_snr <= snr <= self.max_snr for snr in audio_snrs]) diff --git a/data_juicer/ops/filter/audio_size_filter.py b/data_juicer/ops/filter/audio_size_filter.py index dfeb0fcb7..e718041ef 100644 --- a/data_juicer/ops/filter/audio_size_filter.py +++ b/data_juicer/ops/filter/audio_size_filter.py @@ -40,7 +40,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, 
sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.audio_sizes in sample[Fields.stats]: return sample @@ -58,7 +58,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): audio_sizes = sample[Fields.stats][StatsKeys.audio_sizes] keep_bools = np.array([ self.min_size <= audio_size <= self.max_size diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py index d2867b774..f29efc105 100644 --- a/data_juicer/ops/filter/average_line_length_filter.py +++ b/data_juicer/ops/filter/average_line_length_filter.py @@ -37,7 +37,7 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, samples, context=False): + def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] context_key = f'{InterVars.lines}' @@ -58,7 +58,7 @@ def compute_stats(self, samples, context=False): len(cur_text) / len(lines) if len(lines) != 0 else 0.0 return samples - def process(self, samples): + def process_batched(self, samples): if isinstance(samples[Fields.stats], list): return map( lambda stat: self.min_len <= stat[StatsKeys.avg_line_length] <= diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py index 965b368d6..ad0f543a0 100644 --- a/data_juicer/ops/filter/character_repetition_filter.py +++ b/data_juicer/ops/filter/character_repetition_filter.py @@ -41,7 +41,7 @@ def __init__(self, self.min_ratio = min_ratio self.max_ratio = max_ratio - def compute_stats(self, samples): + def compute_stats_batched(self, samples): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] @@ -78,7 +78,7 @@ def compute_stats(self, samples): return samples - def process(self, samples): + def process_batched(self, samples): if isinstance(samples[Fields.stats], list): return map( lambda stat: self.min_ratio <= stat[StatsKeys.char_rep_ratio] diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py index 2966313fc..af97a0fc2 100644 --- a/data_juicer/ops/filter/flagged_words_filter.py +++ b/data_juicer/ops/filter/flagged_words_filter.py @@ -73,7 +73,7 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.flagged_words_ratio in sample[Fields.stats]: return sample @@ -120,6 +120,6 @@ def compute_stats(self, sample, context=False): StatsKeys.flagged_words_ratio] = flagged_words_ratio return sample - def process(self, sample): + def process_single(self, sample): return sample[Fields.stats][ StatsKeys.flagged_words_ratio] <= self.max_ratio diff --git a/data_juicer/ops/filter/image_aesthetics_filter.py b/data_juicer/ops/filter/image_aesthetics_filter.py index 8924aee8d..6cce19ebc 100644 --- a/data_juicer/ops/filter/image_aesthetics_filter.py +++ b/data_juicer/ops/filter/image_aesthetics_filter.py @@ -70,7 +70,7 @@ def __init__(self, self.need_normalized_by_ten = ('shunk031/aesthetics-predictor' in hf_scorer_model) - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.image_aesthetics_scores 
in sample[Fields.stats]: return sample @@ -107,7 +107,7 @@ def compute_stats(self, sample, rank=None, context=False): aesthetics_scores return sample - def process(self, sample): + def process_single(self, sample): aesthetics_scores = ( sample)[Fields.stats][StatsKeys.image_aesthetics_scores] if len(aesthetics_scores) <= 0: diff --git a/data_juicer/ops/filter/image_aspect_ratio_filter.py b/data_juicer/ops/filter/image_aspect_ratio_filter.py index 211a40eee..e069a1943 100644 --- a/data_juicer/ops/filter/image_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/image_aspect_ratio_filter.py @@ -40,7 +40,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.aspect_ratios in sample[Fields.stats]: return sample @@ -66,7 +66,7 @@ def compute_stats(self, sample, context=False): ] return sample - def process(self, sample): + def process_single(self, sample): aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios] keep_bools = np.array([ self.min_ratio <= aspect_ratio <= self.max_ratio diff --git a/data_juicer/ops/filter/image_face_ratio_filter.py b/data_juicer/ops/filter/image_face_ratio_filter.py index 76071f602..4abac8956 100644 --- a/data_juicer/ops/filter/image_face_ratio_filter.py +++ b/data_juicer/ops/filter/image_face_ratio_filter.py @@ -74,7 +74,7 @@ def __init__(self, self.model_key = prepare_model(model_type='opencv_classifier', model_path=cv_classifier) - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.face_ratios in sample[Fields.stats]: return sample @@ -112,7 +112,7 @@ def compute_stats(self, sample, context=False): ] return sample - def process(self, sample): + def process_single(self, sample): face_ratios = sample[Fields.stats][StatsKeys.face_ratios] if len(face_ratios) <= 0: return True diff --git a/data_juicer/ops/filter/image_nsfw_filter.py b/data_juicer/ops/filter/image_nsfw_filter.py index 50ac74a78..ff5f50eb6 100644 --- a/data_juicer/ops/filter/image_nsfw_filter.py +++ b/data_juicer/ops/filter/image_nsfw_filter.py @@ -54,7 +54,7 @@ def __init__(self, pretrained_model_name_or_path=hf_nsfw_model, trust_remote_code=trust_remote_code) - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.image_nsfw_score in sample[Fields.stats]: return sample @@ -84,7 +84,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): itm_scores = sample[Fields.stats][StatsKeys.image_nsfw_score] if len(itm_scores) <= 0: return True diff --git a/data_juicer/ops/filter/image_pair_similarity_filter.py b/data_juicer/ops/filter/image_pair_similarity_filter.py index dcb1b7059..d23351bf2 100644 --- a/data_juicer/ops/filter/image_pair_similarity_filter.py +++ b/data_juicer/ops/filter/image_pair_similarity_filter.py @@ -60,7 +60,7 @@ def __init__(self, pretrained_model_name_or_path=hf_clip, trust_remote_code=trust_remote_code) - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.image_pair_similarity in sample[Fields.stats]: @@ -97,7 +97,7 @@ def compute_stats(self, 
sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): similarity = sample[Fields.stats][StatsKeys.image_pair_similarity] if len(similarity) <= 0: return True diff --git a/data_juicer/ops/filter/image_shape_filter.py b/data_juicer/ops/filter/image_shape_filter.py index c6de1b4bc..064929111 100644 --- a/data_juicer/ops/filter/image_shape_filter.py +++ b/data_juicer/ops/filter/image_shape_filter.py @@ -47,7 +47,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.image_width in sample[Fields.stats] \ and StatsKeys.image_height in sample[Fields.stats]: @@ -76,7 +76,7 @@ def compute_stats(self, sample, context=False): ] return sample - def process(self, sample): + def process_single(self, sample): ws = sample[Fields.stats][StatsKeys.image_width] hs = sample[Fields.stats][StatsKeys.image_height] if len(ws) <= 0: diff --git a/data_juicer/ops/filter/image_size_filter.py b/data_juicer/ops/filter/image_size_filter.py index 74716261e..f4ab8f760 100644 --- a/data_juicer/ops/filter/image_size_filter.py +++ b/data_juicer/ops/filter/image_size_filter.py @@ -40,7 +40,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.image_sizes in sample[Fields.stats]: return sample @@ -58,7 +58,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): image_sizes = sample[Fields.stats][StatsKeys.image_sizes] keep_bools = np.array([ self.min_size <= image_size <= self.max_size diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py index dda7bd153..3c0615dfe 100644 --- a/data_juicer/ops/filter/image_text_matching_filter.py +++ b/data_juicer/ops/filter/image_text_matching_filter.py @@ -74,7 +74,7 @@ def __init__(self, self.horizontal_flip = horizontal_flip self.vertical_flip = vertical_flip - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.image_text_matching_score in sample[Fields.stats]: return sample @@ -139,7 +139,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): itm_scores = sample[Fields.stats][StatsKeys.image_text_matching_score] if len(itm_scores) <= 0: return True diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py index ca74441ca..510ac1f2b 100644 --- a/data_juicer/ops/filter/image_text_similarity_filter.py +++ b/data_juicer/ops/filter/image_text_similarity_filter.py @@ -74,7 +74,7 @@ def __init__(self, self.horizontal_flip = horizontal_flip self.vertical_flip = vertical_flip - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.image_text_similarity in sample[Fields.stats]: return sample @@ -136,7 +136,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - 
def process(self, sample, rank=None): + def process_single(self, sample, rank=None): similarity = sample[Fields.stats][StatsKeys.image_text_similarity] if len(similarity) <= 0: return True diff --git a/data_juicer/ops/filter/image_watermark_filter.py b/data_juicer/ops/filter/image_watermark_filter.py index 4369dcafe..a823baaf0 100644 --- a/data_juicer/ops/filter/image_watermark_filter.py +++ b/data_juicer/ops/filter/image_watermark_filter.py @@ -58,7 +58,7 @@ def __init__(self, pretrained_model_name_or_path=hf_watermark_model, trust_remote_code=trust_remote_code) - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.image_watermark_prob in sample[Fields.stats]: return sample @@ -88,7 +88,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): itm_probs = sample[Fields.stats][StatsKeys.image_watermark_prob] if len(itm_probs) <= 0: return True diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py index 9da08f6a5..a3e9bd3c9 100644 --- a/data_juicer/ops/filter/language_id_score_filter.py +++ b/data_juicer/ops/filter/language_id_score_filter.py @@ -46,7 +46,7 @@ def __init__(self, self.min_score = min_score self.model_key = prepare_model(model_type='fasttext') - def compute_stats(self, sample): + def compute_stats_single(self, sample): # check if it's computed already if StatsKeys.lang in sample[ Fields.stats] and StatsKeys.lang_score in sample[Fields.stats]: @@ -67,7 +67,7 @@ def compute_stats(self, sample): return sample - def process(self, sample): + def process_single(self, sample): if self.lang: return sample[Fields.stats][StatsKeys.lang] in self.lang \ and sample[Fields.stats][StatsKeys.lang_score] >= \ diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py index 16c919406..ed67aee03 100644 --- a/data_juicer/ops/filter/maximum_line_length_filter.py +++ b/data_juicer/ops/filter/maximum_line_length_filter.py @@ -37,7 +37,7 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, samples, context=False): + def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] context_key = f'{InterVars.lines}' @@ -59,7 +59,7 @@ def compute_stats(self, samples, context=False): return samples - def process(self, samples): + def process_batched(self, samples): if isinstance(samples[Fields.stats], list): return map( lambda stat: self.min_len <= stat[StatsKeys.max_line_length] <= diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index ab031157b..ed56b8ab1 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -47,7 +47,7 @@ def __init__(self, lang=lang) self.kl_model_key = prepare_model(model_type='kenlm', lang=lang) - def compute_stats(self, samples, context=False): + def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.sp_model_key}' @@ -79,7 +79,7 @@ def compute_stats(self, samples, context=False): return samples - def process(self, samples): + def process_batched(self, samples): if isinstance(samples[Fields.stats], list): 
return map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl, samples[Fields.stats]) diff --git a/data_juicer/ops/filter/phrase_grounding_recall_filter.py b/data_juicer/ops/filter/phrase_grounding_recall_filter.py index 9a9ba65dd..8592fe0d4 100644 --- a/data_juicer/ops/filter/phrase_grounding_recall_filter.py +++ b/data_juicer/ops/filter/phrase_grounding_recall_filter.py @@ -142,7 +142,7 @@ def __init__(self, for nltk_data_pkg in requires_nltk_data: nltk.download(nltk_data_pkg) - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.phrase_grounding_recall in sample[Fields.stats]: return sample @@ -256,7 +256,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample): + def process_single(self, sample): recalls = sample[Fields.stats][StatsKeys.phrase_grounding_recall] if len(recalls) <= 0: return True diff --git a/data_juicer/ops/filter/special_characters_filter.py b/data_juicer/ops/filter/special_characters_filter.py index 59fa61f52..a5b64aade 100644 --- a/data_juicer/ops/filter/special_characters_filter.py +++ b/data_juicer/ops/filter/special_characters_filter.py @@ -36,7 +36,7 @@ def __init__(self, self.min_ratio = min_ratio self.max_ratio = max_ratio - def compute_stats(self, samples): + def compute_stats_batched(self, samples): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] @@ -52,7 +52,7 @@ def compute_stats(self, samples): return samples - def process(self, samples): + def process_batched(self, samples): if isinstance(samples[Fields.stats], list): return map( lambda stat: self.min_ratio <= stat[ diff --git a/data_juicer/ops/filter/specified_field_filter.py b/data_juicer/ops/filter/specified_field_filter.py index 7f79a98b8..86aff2426 100644 --- a/data_juicer/ops/filter/specified_field_filter.py +++ b/data_juicer/ops/filter/specified_field_filter.py @@ -33,10 +33,10 @@ def __init__(self, self.field_key = field_key self.target_value = target_value - def compute_stats(self, sample): + def compute_stats_single(self, sample): return sample - def process(self, sample): + def process_single(self, sample): if not (self.field_key and self.target_value): return True diff --git a/data_juicer/ops/filter/specified_numeric_field_filter.py b/data_juicer/ops/filter/specified_numeric_field_filter.py index 00cb4226e..693be3392 100644 --- a/data_juicer/ops/filter/specified_numeric_field_filter.py +++ b/data_juicer/ops/filter/specified_numeric_field_filter.py @@ -49,10 +49,10 @@ def __init__(self, self.min_value = min_value self.max_value = max_value - def compute_stats(self, sample): + def compute_stats_single(self, sample): return sample - def process(self, sample): + def process_single(self, sample): if not self.field_key: return True diff --git a/data_juicer/ops/filter/stopwords_filter.py b/data_juicer/ops/filter/stopwords_filter.py index 1d9f59b7b..3b7ed11bd 100644 --- a/data_juicer/ops/filter/stopwords_filter.py +++ b/data_juicer/ops/filter/stopwords_filter.py @@ -74,7 +74,7 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.stopwords_ratio in sample[Fields.stats]: return sample @@ -121,6 +121,6 @@ def compute_stats(self, sample, context=False): sample[Fields.stats][StatsKeys.stopwords_ratio] = stopwords_ratio 
return sample - def process(self, sample): + def process_single(self, sample): return sample[Fields.stats][ StatsKeys.stopwords_ratio] >= self.min_ratio diff --git a/data_juicer/ops/filter/suffix_filter.py b/data_juicer/ops/filter/suffix_filter.py index 52a833691..ea7868399 100644 --- a/data_juicer/ops/filter/suffix_filter.py +++ b/data_juicer/ops/filter/suffix_filter.py @@ -26,10 +26,10 @@ def __init__(self, suffixes: Union[str, List[str]] = [], *args, **kwargs): else: self.suffixes = suffixes - def compute_stats(self, sample): + def compute_stats_single(self, sample): return sample - def process(self, sample): + def process_single(self, sample): if self.suffixes: if sample[Fields.suffix] in self.suffixes: return True diff --git a/data_juicer/ops/filter/text_action_filter.py b/data_juicer/ops/filter/text_action_filter.py index 44c67920d..22cf8e145 100644 --- a/data_juicer/ops/filter/text_action_filter.py +++ b/data_juicer/ops/filter/text_action_filter.py @@ -40,7 +40,7 @@ def __init__(self, self.action_tags = ['VV', 'VB', 'VBP', 'VBZ', 'VBD', 'VBG', 'VBN'] self.min_action_num = min_action_num - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.num_action in sample[Fields.stats]: return sample @@ -59,7 +59,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): num_action = sample[Fields.stats][StatsKeys.num_action] if self.min_action_num <= num_action: return True diff --git a/data_juicer/ops/filter/text_entity_dependency_filter.py b/data_juicer/ops/filter/text_entity_dependency_filter.py index 6e4ec9f36..c45e645b7 100644 --- a/data_juicer/ops/filter/text_entity_dependency_filter.py +++ b/data_juicer/ops/filter/text_entity_dependency_filter.py @@ -51,7 +51,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.num_dependency_edges in sample[Fields.stats]: return sample @@ -86,7 +86,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): num_dependency_edges = sample[Fields.stats][ StatsKeys.num_dependency_edges] keep_bools = np.array([ diff --git a/data_juicer/ops/filter/token_num_filter.py b/data_juicer/ops/filter/token_num_filter.py index de3349315..38a6c71d6 100644 --- a/data_juicer/ops/filter/token_num_filter.py +++ b/data_juicer/ops/filter/token_num_filter.py @@ -47,7 +47,7 @@ def __init__(self, pretrained_model_name_or_path=hf_tokenizer, return_model=False) - def compute_stats(self, sample): + def compute_stats_single(self, sample): # check if it's computed already if StatsKeys.num_token in sample[Fields.stats]: return sample @@ -59,7 +59,7 @@ def compute_stats(self, sample): sample[Fields.stats][StatsKeys.num_token] = len(tokens) return sample - def process(self, sample): + def process_single(self, sample): if self.min_num <= sample[Fields.stats][ StatsKeys.num_token] <= self.max_num: return True diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py index 31c242473..b01773134 100644 --- a/data_juicer/ops/filter/video_aesthetics_filter.py +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -112,7 +112,7 @@ def __init__(self, ('' if frame_sampling_method == 'all_keyframes' else 
f'-{frame_num}') - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.video_frames_aesthetics_score in sample[Fields.stats]: return sample @@ -187,7 +187,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample): + def process_single(self, sample): aesthetics_scores = ( sample)[Fields.stats][StatsKeys.video_frames_aesthetics_score] if len(aesthetics_scores) <= 0: diff --git a/data_juicer/ops/filter/video_aspect_ratio_filter.py b/data_juicer/ops/filter/video_aspect_ratio_filter.py index 49f684ebd..a06f54687 100644 --- a/data_juicer/ops/filter/video_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/video_aspect_ratio_filter.py @@ -45,7 +45,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.video_aspect_ratios in sample[Fields.stats]: return sample @@ -76,7 +76,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): video_aspect_ratios = sample[Fields.stats][ StatsKeys.video_aspect_ratios] diff --git a/data_juicer/ops/filter/video_duration_filter.py b/data_juicer/ops/filter/video_duration_filter.py index 1cccf87c7..0754763c7 100644 --- a/data_juicer/ops/filter/video_duration_filter.py +++ b/data_juicer/ops/filter/video_duration_filter.py @@ -46,7 +46,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.video_duration in sample[Fields.stats]: return sample @@ -77,7 +77,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): video_durations = sample[Fields.stats][StatsKeys.video_duration] keep_bools = np.array([ self.min_duration <= duration <= self.max_duration diff --git a/data_juicer/ops/filter/video_frames_text_similarity_filter.py b/data_juicer/ops/filter/video_frames_text_similarity_filter.py index ddcbff1e7..f5d6ea211 100644 --- a/data_juicer/ops/filter/video_frames_text_similarity_filter.py +++ b/data_juicer/ops/filter/video_frames_text_similarity_filter.py @@ -107,7 +107,7 @@ def __init__(self, ('' if frame_sampling_method == 'all_keyframes' else f'-{frame_num}') - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.video_frames_text_similarity in sample[Fields.stats]: return sample @@ -201,7 +201,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): similarity = sample[Fields.stats][ StatsKeys.video_frames_text_similarity] if len(similarity) <= 0: diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py index e8e63f052..42b14843f 100644 --- a/data_juicer/ops/filter/video_motion_score_filter.py +++ b/data_juicer/ops/filter/video_motion_score_filter.py @@ -105,7 +105,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, 
context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.video_motion_score in sample[Fields.stats]: return sample @@ -182,7 +182,7 @@ def compute_stats(self, sample, context=False): ] return sample - def process(self, sample): + def process_single(self, sample): video_motion_scores = sample[Fields.stats][ StatsKeys.video_motion_score] diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py index a96151f3e..555b8466f 100644 --- a/data_juicer/ops/filter/video_nsfw_filter.py +++ b/data_juicer/ops/filter/video_nsfw_filter.py @@ -93,7 +93,7 @@ def __init__(self, ('' if frame_sampling_method == 'all_keyframes' else f'-{frame_num}') - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.video_nsfw_score in sample[Fields.stats]: return sample @@ -163,7 +163,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): itm_scores = sample[Fields.stats][StatsKeys.video_nsfw_score] if len(itm_scores) <= 0: return True diff --git a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py index a36214fbc..214e28944 100644 --- a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py +++ b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py @@ -101,7 +101,7 @@ def get_reader(self, rank): self.reader.device = device return self.reader - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.video_ocr_area_ratio in sample[Fields.stats]: return sample @@ -182,7 +182,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample): + def process_single(self, sample): video_ocr_area_ratios = sample[Fields.stats][ StatsKeys.video_ocr_area_ratio] keep_bools = np.array([ diff --git a/data_juicer/ops/filter/video_resolution_filter.py b/data_juicer/ops/filter/video_resolution_filter.py index 61e5d13cd..806190908 100644 --- a/data_juicer/ops/filter/video_resolution_filter.py +++ b/data_juicer/ops/filter/video_resolution_filter.py @@ -50,7 +50,7 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.video_width in sample[Fields.stats] \ and StatsKeys.video_height in sample[Fields.stats]: @@ -95,7 +95,7 @@ def compute_stats(self, sample, context=False): return sample - def process(self, sample): + def process_single(self, sample): ws = sample[Fields.stats][StatsKeys.video_width] hs = sample[Fields.stats][StatsKeys.video_height] keep_bools = np.array([ diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index f85cfaa54..543419a1b 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -87,13 +87,13 @@ def __init__(self, tag_field_name=self.tag_field_name, ) - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): sample = self.tagging_producer.process(sample, 
rank, context) return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): video_tags = sample[self.tag_field_name] if len(video_tags) <= 0: return True diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py index c5ddfc8b7..b8413139f 100644 --- a/data_juicer/ops/filter/video_watermark_filter.py +++ b/data_juicer/ops/filter/video_watermark_filter.py @@ -96,7 +96,7 @@ def __init__(self, ('' if frame_sampling_method == 'all_keyframes' else f'-{frame_num}') - def compute_stats(self, sample, rank=None, context=False): + def compute_stats_single(self, sample, rank=None, context=False): # check if it's computed already if StatsKeys.video_watermark_prob in sample[Fields.stats]: return sample @@ -164,7 +164,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): itm_probs = sample[Fields.stats][StatsKeys.video_watermark_prob] if len(itm_probs) <= 0: return True diff --git a/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py b/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py index b6434c0f4..f32f0e8ff 100644 --- a/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py +++ b/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py @@ -50,7 +50,7 @@ def __init__( self.capture_stderr = capture_stderr self.overwrite_output = overwrite_output - def process(self, sample): + def process_single(self, sample): # there is no audio in this sample if self.audio_key not in sample or not sample[self.audio_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/chinese_convert_mapper.py b/data_juicer/ops/mapper/chinese_convert_mapper.py index e18fa0afc..2ce453e90 100644 --- a/data_juicer/ops/mapper/chinese_convert_mapper.py +++ b/data_juicer/ops/mapper/chinese_convert_mapper.py @@ -84,7 +84,7 @@ def __init__(self, mode: str = 's2t', *args, **kwargs): self.mode = mode prepare_converter(self.mode) - def process(self, samples): + def process_batched(self, samples): prepare_converter(self.mode) samples[self.text_key] = [ diff --git a/data_juicer/ops/mapper/clean_copyright_mapper.py b/data_juicer/ops/mapper/clean_copyright_mapper.py index 8908d33e9..4247f8ac4 100644 --- a/data_juicer/ops/mapper/clean_copyright_mapper.py +++ b/data_juicer/ops/mapper/clean_copyright_mapper.py @@ -54,7 +54,7 @@ def _process_single_sample(self, sample): sample = '\n'.join(lines[skip:]) return sample - def process(self, samples): + def process_batched(self, samples): samples[self.text_key] = [ self._process_single_sample(text) for text in samples[self.text_key] diff --git a/data_juicer/ops/mapper/clean_email_mapper.py b/data_juicer/ops/mapper/clean_email_mapper.py index 90dcca60f..572a0370a 100644 --- a/data_juicer/ops/mapper/clean_email_mapper.py +++ b/data_juicer/ops/mapper/clean_email_mapper.py @@ -36,7 +36,7 @@ def __init__(self, self.repl = repl - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): if not re.search(self.pattern, text, flags=re.DOTALL): continue diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py index 477c46846..197c20fe6 100644 --- a/data_juicer/ops/mapper/clean_html_mapper.py +++ b/data_juicer/ops/mapper/clean_html_mapper.py @@ -27,7 +27,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) AUTOINSTALL.check(['selectolax']) - def process(self, samples): 
+ def process_batched(self, samples): def _clean_html(raw_html): raw_html = raw_html.replace('
  • ', '\n*') diff --git a/data_juicer/ops/mapper/clean_ip_mapper.py b/data_juicer/ops/mapper/clean_ip_mapper.py index 709037ddd..193995760 100644 --- a/data_juicer/ops/mapper/clean_ip_mapper.py +++ b/data_juicer/ops/mapper/clean_ip_mapper.py @@ -40,7 +40,7 @@ def __init__(self, self.pattern = pattern[2:-1] self.repl = repl - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): if not re.search(self.pattern, text, flags=re.DOTALL): continue diff --git a/data_juicer/ops/mapper/clean_links_mapper.py b/data_juicer/ops/mapper/clean_links_mapper.py index f08abc78f..70efb36ca 100644 --- a/data_juicer/ops/mapper/clean_links_mapper.py +++ b/data_juicer/ops/mapper/clean_links_mapper.py @@ -46,7 +46,7 @@ def __init__(self, self.pattern = pattern[2:-1] self.repl = repl - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): if not re.search(self.pattern, text, flags=re.DOTALL): continue diff --git a/data_juicer/ops/mapper/expand_macro_mapper.py b/data_juicer/ops/mapper/expand_macro_mapper.py index b83455103..7bc3f25ca 100644 --- a/data_juicer/ops/mapper/expand_macro_mapper.py +++ b/data_juicer/ops/mapper/expand_macro_mapper.py @@ -57,7 +57,7 @@ def _build_non_arg_macros_dict(self, file_content): macros[macro_name] = macro_val return macros - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): non_arg_macros = self._build_non_arg_macros_dict(text) diff --git a/data_juicer/ops/mapper/extract_qa_mapper.py b/data_juicer/ops/mapper/extract_qa_mapper.py index 8a41efeb4..23d99c1af 100644 --- a/data_juicer/ops/mapper/extract_qa_mapper.py +++ b/data_juicer/ops/mapper/extract_qa_mapper.py @@ -133,7 +133,7 @@ def _extract_qa(self, output): return qa_list - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): model, processor = get_model(self.model_key, rank, self.use_cuda()) if self.enable_vllm: diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py index e2323c3b9..1c1124977 100644 --- a/data_juicer/ops/mapper/fix_unicode_mapper.py +++ b/data_juicer/ops/mapper/fix_unicode_mapper.py @@ -35,7 +35,7 @@ def __init__(self, normalization: str = None, *args, **kwargs): 'supported. 
Can only be one of ' '["NFC", "NFKC", "NFD", "NFKD"]') - def process(self, samples): + def process_batched(self, samples): samples[self.text_key] = [ ftfy.fix_text(text, normalization=self.normalization) for text in samples[self.text_key] diff --git a/data_juicer/ops/mapper/generate_instruction_mapper.py b/data_juicer/ops/mapper/generate_instruction_mapper.py index 9fafa94e3..91547b245 100644 --- a/data_juicer/ops/mapper/generate_instruction_mapper.py +++ b/data_juicer/ops/mapper/generate_instruction_mapper.py @@ -246,7 +246,7 @@ def max_rouge_l_score(self, reference, candidates): max_score = rouge_l_score return max_score - def process(self, sample=None, rank=None): + def process_single(self, sample=None, rank=None): model, processor = get_model(self.model_key, rank=rank) random_qa_samples = random.sample(self.seed_qa_samples, diff --git a/data_juicer/ops/mapper/image_blur_mapper.py b/data_juicer/ops/mapper/image_blur_mapper.py index 536f952a9..1b1379b7d 100644 --- a/data_juicer/ops/mapper/image_blur_mapper.py +++ b/data_juicer/ops/mapper/image_blur_mapper.py @@ -53,7 +53,7 @@ def __init__(self, else: self.blur = ImageFilter.GaussianBlur(radius) - def process(self, sample, context=False): + def process_single(self, sample, context=False): # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py index 1500a074a..6d4d42cfb 100644 --- a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py @@ -248,7 +248,7 @@ def _process_single_sample(self, sample): return [generated_sample] - def process(self, samples): + def process_batched(self, samples): # reconstruct samples from "dict of lists" to "list of dicts" reconstructed_samples = [] for i in range(len(samples[self.text_key])): diff --git a/data_juicer/ops/mapper/image_captioning_mapper.py b/data_juicer/ops/mapper/image_captioning_mapper.py index 5f04fa97f..5e569022f 100644 --- a/data_juicer/ops/mapper/image_captioning_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_mapper.py @@ -268,7 +268,7 @@ def _reduce_captions_per_image(self, chunk, generated_text_candidates_single_chunk[max_index]) return new_generated_text_per_chunk - def process(self, samples, rank=None): + def process_batched(self, samples, rank=None): """ Note: This is a batched_OP, whose input and output type are diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py index a69d8ac6a..3c958b47f 100644 --- a/data_juicer/ops/mapper/image_diffusion_mapper.py +++ b/data_juicer/ops/mapper/image_diffusion_mapper.py @@ -207,7 +207,7 @@ def _process_single_sample(self, ori_sample, rank=None, context=False): return generated_samples - def process(self, samples, rank=None, context=False): + def process_batched(self, samples, rank=None, context=False): """ Note: This is a batched_OP, whose the input and output type are diff --git a/data_juicer/ops/mapper/image_face_blur_mapper.py b/data_juicer/ops/mapper/image_face_blur_mapper.py index e3d37e21b..23b37d126 100644 --- a/data_juicer/ops/mapper/image_face_blur_mapper.py +++ b/data_juicer/ops/mapper/image_face_blur_mapper.py @@ -82,7 +82,7 @@ def __init__(self, self.model_key = prepare_model(model_type='opencv_classifier', model_path=cv_classifier) - def process(self, sample, context=False): + def 
process_single(self, sample, context=False): # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py index 0bd2b89e2..a9f4b67cb 100644 --- a/data_juicer/ops/mapper/image_tagging_mapper.py +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -51,7 +51,7 @@ def __init__(self, self.transform = get_transform(image_size=384) self.tag_field_name = tag_field_name - def process(self, sample, rank=None, context=False): + def process_single(self, sample, rank=None, context=False): # check if it's generated already if self.tag_field_name in sample: return sample diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index 9a253c9c2..de2203dde 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -124,7 +124,7 @@ def __init__(self, else: self.aug = aug_pipeline - def process(self, samples): + def process_batched(self, samples): # no augmentation methods are opened if len(self.aug) == 0: if self.keep_original_sample: diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py index adc718beb..77157173e 100644 --- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py +++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py @@ -130,7 +130,7 @@ def __init__(self, self.aug_pipeline.append( nlpcda.EquivalentChar(create_num=create_num)) - def process(self, samples): + def process_batched(self, samples): # no augmentation methods are opened if len(self.aug_pipeline) == 0: if self.keep_original_sample: diff --git a/data_juicer/ops/mapper/optimize_instruction_mapper.py b/data_juicer/ops/mapper/optimize_instruction_mapper.py index a9ec0564c..34e2affbf 100644 --- a/data_juicer/ops/mapper/optimize_instruction_mapper.py +++ b/data_juicer/ops/mapper/optimize_instruction_mapper.py @@ -93,7 +93,7 @@ def __init__(self, trust_remote_code=trust_remote_code) self.sampling_params = sampling_params - def process(self, sample=None, rank=None): + def process_single(self, sample=None, rank=None): model, processor = get_model(self.model_key, rank=rank) messages = [{ diff --git a/data_juicer/ops/mapper/punctuation_normalization_mapper.py b/data_juicer/ops/mapper/punctuation_normalization_mapper.py index 18aa12c56..6ad7eb8c0 100644 --- a/data_juicer/ops/mapper/punctuation_normalization_mapper.py +++ b/data_juicer/ops/mapper/punctuation_normalization_mapper.py @@ -57,7 +57,7 @@ def __init__(self, *args, **kwargs): '►': '-', } - def process(self, samples): + def process_batched(self, samples): samples[self.text_key] = [ ''.join([self.punctuation_unicode.get(c, c) for c in text]) for text in samples[self.text_key] diff --git a/data_juicer/ops/mapper/remove_bibliography_mapper.py b/data_juicer/ops/mapper/remove_bibliography_mapper.py index 1eecd66d2..e876eb3e0 100644 --- a/data_juicer/ops/mapper/remove_bibliography_mapper.py +++ b/data_juicer/ops/mapper/remove_bibliography_mapper.py @@ -29,7 +29,7 @@ def __init__(self, *args, **kwargs): self.pattern += r'\\bibliography\{.*\}' self.pattern += r').*$' - def process(self, samples): + def process_batched(self, samples): samples[self.text_key] = [ re.sub(pattern=self.pattern, repl=r'', diff --git a/data_juicer/ops/mapper/remove_comments_mapper.py b/data_juicer/ops/mapper/remove_comments_mapper.py index 09fe4e5ef..431c42c1e 100644 --- a/data_juicer/ops/mapper/remove_comments_mapper.py 
+++ b/data_juicer/ops/mapper/remove_comments_mapper.py @@ -39,7 +39,7 @@ def __init__(self, self.inline = inline self.multiline = multiline - def process(self, samples): + def process_batched(self, samples): # TODO: remove different comments by sample type for idx, text in enumerate(samples[self.text_key]): diff --git a/data_juicer/ops/mapper/remove_header_mapper.py b/data_juicer/ops/mapper/remove_header_mapper.py index bb967e929..b668c2fab 100644 --- a/data_juicer/ops/mapper/remove_header_mapper.py +++ b/data_juicer/ops/mapper/remove_header_mapper.py @@ -36,7 +36,7 @@ def __init__(self, drop_no_head: bool = True, *args, **kwargs): self.drop_no_head = drop_no_head - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): if not re.search(self.pattern, text, flags=re.DOTALL): if self.drop_no_head: diff --git a/data_juicer/ops/mapper/remove_long_words_mapper.py b/data_juicer/ops/mapper/remove_long_words_mapper.py index 5aea47516..e59f3a867 100644 --- a/data_juicer/ops/mapper/remove_long_words_mapper.py +++ b/data_juicer/ops/mapper/remove_long_words_mapper.py @@ -43,7 +43,7 @@ def should_keep_long_word(self, word): else: return False - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): sentences = split_on_newline_tab_whitespace(text) sentences = [[[ diff --git a/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py b/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py index 371697efc..874b35235 100644 --- a/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py +++ b/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py @@ -35,7 +35,7 @@ def __init__(self, else: self.pattern += u']' - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): if not re.search(self.pattern, text, flags=re.DOTALL): continue diff --git a/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py b/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py index add0a719e..d3a5349a5 100644 --- a/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py +++ b/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py @@ -45,7 +45,7 @@ def __init__(self, self.remove_regex = re.compile(r'[^a-zA-Z0-9\u4e00-\u9fa5\n\t ]' ) if ignore_special_character else None - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): lines = [e for e in text.split('\n')] new_lines = [] diff --git a/data_juicer/ops/mapper/remove_specific_chars_mapper.py b/data_juicer/ops/mapper/remove_specific_chars_mapper.py index 78ca55e62..ae3281e86 100644 --- a/data_juicer/ops/mapper/remove_specific_chars_mapper.py +++ b/data_juicer/ops/mapper/remove_specific_chars_mapper.py @@ -30,7 +30,7 @@ def __init__(self, else: self.pattern = None - def process(self, samples): + def process_batched(self, samples): if self.pattern is None: return samples diff --git a/data_juicer/ops/mapper/remove_table_text_mapper.py b/data_juicer/ops/mapper/remove_table_text_mapper.py index 8273c8dab..ff2b07a4f 100644 --- a/data_juicer/ops/mapper/remove_table_text_mapper.py +++ b/data_juicer/ops/mapper/remove_table_text_mapper.py @@ -34,7 +34,7 @@ def __init__(self, self.max_col = max_col self.pattern = r'(?<=\n)((\S+?)([ |\t](\S+?)){%d}\n+){2,}' - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): for idx in range(self.min_col - 1, 
self.max_col): pattern = re.compile(self.pattern % idx) diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py index eea948097..de93b1727 100644 --- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py +++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py @@ -47,7 +47,7 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings): should_keep = all([(i_substr not in word) for i_substr in substrings]) return should_keep - def process(self, samples): + def process_batched(self, samples): for idx, text in enumerate(samples[self.text_key]): if self.tokenization: tokenizer = get_model(self.model_key) diff --git a/data_juicer/ops/mapper/replace_content_mapper.py b/data_juicer/ops/mapper/replace_content_mapper.py index 324cc6357..9cd552ad8 100644 --- a/data_juicer/ops/mapper/replace_content_mapper.py +++ b/data_juicer/ops/mapper/replace_content_mapper.py @@ -44,7 +44,7 @@ def _prepare_pattern(self, pattern: str) -> re.Pattern: pattern = pattern[2:-1] return re.compile(pattern, flags=re.DOTALL) - def process(self, samples): + def process_batched(self, samples): if self.pattern is None: return samples diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py index c71479de4..ea0608894 100644 --- a/data_juicer/ops/mapper/sentence_split_mapper.py +++ b/data_juicer/ops/mapper/sentence_split_mapper.py @@ -25,7 +25,7 @@ def __init__(self, lang: str = 'en', *args, **kwargs): self.lang = lang self.model_key = prepare_model(model_type='nltk', lang=lang) - def process(self, samples): + def process_batched(self, samples): nltk_model = get_model(self.model_key) diff --git a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py index 9eaa960e7..3543e26fb 100644 --- a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py @@ -115,7 +115,7 @@ def _process_single_sample(self, sample, rank=None): captioned_sample[self.video_key] = left_video_keys return [captioned_sample] - def process(self, samples, rank=None): + def process_batched(self, samples, rank=None): # reconstruct samples from "dict of lists" to "list of dicts" reconstructed_samples = [] for i in range(len(samples[self.text_key])): diff --git a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py index 08eee2add..d5381dc3a 100644 --- a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py @@ -324,7 +324,7 @@ def _reduce_captions(self, chunk, generated_text_candidates_single_chunk): generated_text_candidates_single_chunk[max_index]) return generated_text_per_chunk - def process(self, samples, rank=None, context=False): + def process_batched(self, samples, rank=None, context=False): """ :param samples: :return: diff --git a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py index 02cf781a0..e73f4b756 100644 --- a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py @@ -234,7 +234,7 @@ def _process_single_sample(self, sample, rank=None): captioned_sample[self.text_key] = captioned_texts return 
[captioned_sample] - def process(self, samples, rank=None): + def process_batched(self, samples, rank=None): # reconstruct samples from "dict of lists" to "list of dicts" reconstructed_samples = [] for i in range(len(samples[self.text_key])): diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py index ebb26c53b..06ab5f84b 100644 --- a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py @@ -331,7 +331,7 @@ def _reduce_captions(self, chunk, generated_text_candidates_single_chunk): generated_text_candidates_single_chunk[max_index]) return generated_text_per_chunk - def process(self, samples, rank=None, context=False): + def process_batched(self, samples, rank=None, context=False): """ :param samples: :return: diff --git a/data_juicer/ops/mapper/video_face_blur_mapper.py b/data_juicer/ops/mapper/video_face_blur_mapper.py index f30917536..3fb38a08f 100644 --- a/data_juicer/ops/mapper/video_face_blur_mapper.py +++ b/data_juicer/ops/mapper/video_face_blur_mapper.py @@ -82,7 +82,7 @@ def __init__(self, self.model_key = prepare_model(model_type='opencv_classifier', model_path=cv_classifier) - def process(self, sample, context=False): + def process_single(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py b/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py index c711a6ae8..e1216d67f 100644 --- a/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py +++ b/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py @@ -50,7 +50,7 @@ def __init__( self.capture_stderr = capture_stderr self.overwrite_output = overwrite_output - def process(self, sample): + def process_single(self, sample): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/video_remove_watermark_mapper.py b/data_juicer/ops/mapper/video_remove_watermark_mapper.py index 2c3166e8b..4e4e7e9dc 100644 --- a/data_juicer/ops/mapper/video_remove_watermark_mapper.py +++ b/data_juicer/ops/mapper/video_remove_watermark_mapper.py @@ -202,7 +202,7 @@ def _clean_watermark(self, frame, watermark_mask): new_np_frame = cv2.inpaint(np_frame, watermark_mask, 3, cv2.INPAINT_NS) return av.VideoFrame.from_ndarray(new_np_frame, format='bgr24') - def process(self, sample, context=False): + def process_single(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py index 99192c9c1..253e1af79 100644 --- a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py +++ b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py @@ -102,7 +102,7 @@ def __init__( self.max_ratio = Fraction(str(max_ratio).replace(':', '/')) self.strategy = strategy - def process(self, sample): + def process_single(self, sample): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/video_resize_resolution_mapper.py b/data_juicer/ops/mapper/video_resize_resolution_mapper.py index 574dd04d6..89e0e7054 
100644 --- a/data_juicer/ops/mapper/video_resize_resolution_mapper.py +++ b/data_juicer/ops/mapper/video_resize_resolution_mapper.py @@ -85,7 +85,7 @@ def __init__(self, self.force_original_aspect_ratio = force_original_aspect_ratio self.force_divisible_by = force_divisible_by - def process(self, sample, context=False): + def process_single(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/video_split_by_duration_mapper.py b/data_juicer/ops/mapper/video_split_by_duration_mapper.py index 0a41d3240..dcfcdfc12 100644 --- a/data_juicer/ops/mapper/video_split_by_duration_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_duration_mapper.py @@ -143,7 +143,7 @@ def _process_single_sample(self, sample): split_sample[self.video_key] = split_video_keys return [split_sample] - def process(self, samples): + def process_batched(self, samples): # reconstruct samples from "dict of lists" to "list of dicts" reconstructed_samples = [] for i in range(len(samples[self.text_key])): diff --git a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py index 0a4a7c593..fdfb73f93 100644 --- a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py @@ -125,7 +125,7 @@ def _process_single_sample(self, sample): split_sample[self.video_key] = split_video_keys return [split_sample] - def process(self, samples): + def process_batched(self, samples): # reconstruct samples from "dict of lists" to "list of dicts" reconstructed_samples = [] for i in range(len(samples[self.text_key])): diff --git a/data_juicer/ops/mapper/video_split_by_scene_mapper.py b/data_juicer/ops/mapper/video_split_by_scene_mapper.py index 7ce921e09..36a428ae7 100644 --- a/data_juicer/ops/mapper/video_split_by_scene_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_scene_mapper.py @@ -81,7 +81,7 @@ def __init__(self, for key in avaliable_kwargs if key in kwargs } - def process(self, sample, context=False): + def process_single(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: sample[Fields.source_file] = [] diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 164fc46bd..1b6aea7c2 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -47,7 +47,7 @@ def __init__(self, self.tag_field_name = tag_field_name - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): # check if it's generated already if self.tag_field_name in sample: return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index 99cb96aa4..5fc5c063f 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -74,7 +74,7 @@ def __init__(self, self.tag_field_name = tag_field_name - def process(self, sample, rank=None, context=False): + def process_single(self, sample, rank=None, context=False): # check if it's generated already if self.tag_field_name in sample: return sample From 3b470f24fc5865fe6167fddeae4969ee6f0ba6d1 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Thu, 17 Oct 
2024 17:54:30 +0800 Subject: [PATCH 10/13] * update docs for this modification --- docs/DeveloperGuide.md | 24 +++++++++++++----------- docs/DeveloperGuide_ZH.md | 22 ++++++++++++---------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index 42fa5d09e..3526daeb6 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -51,7 +51,8 @@ class StatsKeys(object): ``` 2. Create a new OP file `text_length_filter.py` in the corresponding `data_juicer/ops/filter/` directory as follows. - - Because it's a Filter OP, so the new OP needs to inherit from the basic `Filter` class in the `base_op.py`, and be decorated with `OPERATORS` to register itself automatically. + - It's a Filter OP, so the new OP needs to inherit from the basic `Filter` class in `base_op.py` and be decorated with `OPERATORS` to register itself automatically. + - For convenience, we can implement the core functions `compute_stats_single` and `process_single` in a single-sample way, where both the input and the output are a single sample dictionary. If you are familiar with batched processing in Data-Juicer, you can instead implement the batched versions directly by overriding the `compute_stats_batched` and `process_batched` functions, which is slightly faster than the single-sample version. Their input and output are a column-wise dict containing multiple samples. ```python import sys @@ -89,7 +90,7 @@ class StatsKeys(object): self.min_len = min_len self.max_len = max_len - def compute_stats(self, sample): + def compute_stats_single(self, sample): # check if it's computed already if StatsKeys.text_len in sample[Fields.stats]: return sample @@ -97,14 +98,14 @@ class StatsKeys(object): sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key]) return sample - def process(self, sample): + def process_single(self, sample): if self.min_len <= sample[Fields.stats][StatsKeys.text_len] <= self.max_len: return True else: return False ``` - - If Hugging Face models are used within an operator, you might want to leverage GPU acceleration. To achieve this, declare `_accelerator = 'cuda'` in the constructor, and ensure that `compute_stats` and `process` methods accept an additional positional argument `rank`. + - If Hugging Face models are used within an operator, you might want to leverage GPU acceleration. To achieve this, declare `_accelerator = 'cuda'` in the constructor, and ensure that the `compute_stats_single/batched` and `process_single/batched` methods accept an additional positional argument `rank`. ```python # ... (same as above) @@ -121,14 +122,15 @@ class StatsKeys(object): **kwargs): # ... (same as above) - def compute_stats(self, sample, rank=None): + def compute_stats_single(self, sample, rank=None): # ... (same as above) - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): # ... (same as above) ``` - - If the operator processes data in batches rather than a single sample, it is necessary to declare `_batched_op = True`. + - If the operator processes data in batches rather than single samples, or you want to enable batched processing for it, it is necessary to declare `_batched_op = True`. + - The original single-sample `compute_stats_single` and `process_single` functions can be kept unchanged: the default batched versions in Data-Juicer will split each batch and invoke them sample by sample, so batched processing is supported automatically. Alternatively, you can implement your own, more efficient batched versions. ```python # ...
(import some other libraries) OP_NAME = 'image_diffusion_mapper' @@ -143,7 +145,7 @@ class StatsKeys(object): **kwargs): super().__init__(*args, **kwargs) - def process(self, samples): + def process_batched(self, samples): # ... (some codes) ``` @@ -162,7 +164,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample): + def process_single(self, sample): # ... (some codes) # captions[index] is the prompt for diffusion model related_parameters = self.add_parameters( @@ -186,7 +188,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample): + def process_single(self, sample): # ... (some codes) split_video_path = transfer_filename( original_video_path, OP_NAME, **self._init_parameters) @@ -396,7 +398,7 @@ class PerplexityFilter(Filter): AUTOINSTALL.check(['sentencepiece', 'kenlm']) # ... (some codes) - def process(self, sample): + def process_single(self, sample): # ... (some codes) ``` diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 1574287f6..20ba0e261 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -47,6 +47,7 @@ class StatsKeys(object): 2. 在 `data_juicer/ops/filter/` 目录下创建一个新的算子文件 `text_length_filter.py`,内容如下: - 因为它是一个 Filter 算子,所以需要继承 `base_op.py` 中的 `Filter` 基类,并用 `OPERATORS` 修饰以实现自动注册。 + - 为了方便实现,我们可以以单样本处理的方式实现两个核心方法 `compute_stats_single` 和 `process_single`,它们的输入输出均为单个样本的字典结构。如果你比较熟悉 Data-Juicer 中的batch化处理,你也可以通过覆写 `compute_stats_batched` 和 `process_batched` 方法直接实现它们的batch化版本,它的处理会比单样本版本稍快一些。它们的输入和输出则是按列存储的字典结构,其中包括多个样本。 ```python import sys @@ -84,7 +85,7 @@ class StatsKeys(object): self.min_len = min_len self.max_len = max_len - def compute_stats(self, sample): + def compute_stats_single(self, sample): # check if it's computed already if StatsKeys.text_len in sample[Fields.stats]: return sample @@ -92,14 +93,14 @@ class StatsKeys(object): sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key]) return sample - def process(self, sample): + def process_single(self, sample): if self.min_len <= sample[Fields.stats][StatsKeys.text_len] <= self.max_len: return True else: return False ``` - - 如果在算子中使用了 Hugging Face 模型,您可能希望利用 GPU 加速。为了实现这一点,请在构造函数中声明 `_accelerator = 'cuda'`,并确保 `compute_stats` 和 `process` 方法接受一个额外的位置参数 `rank`。 + - 如果在算子中使用了 Hugging Face 模型,您可能希望利用 GPU 加速。为了实现这一点,请在构造函数中声明 `_accelerator = 'cuda'`,并确保 `compute_stats_single/batched` 和 `process_single/batched` 方法接受一个额外的位置参数 `rank`。 ```python # ... (same as above) @@ -116,14 +117,15 @@ class StatsKeys(object): **kwargs): # ... (same as above) - def compute_stats(self, sample, rank=None): + def compute_stats_single(self, sample, rank=None): # ... (same as above) - def process(self, sample, rank=None): + def process_single(self, sample, rank=None): # ... (same as above) ``` - - 如果算子批量处理数据,输入不是一个样本而是一个batch,需要声明`_batched_op = True`。 + - 如果算子批量处理数据,输入不是一个样本而是一个batch,或者你想在单样本实现上直接激活batch化处理,需要声明`_batched_op = True`。 + - 对于单样本实现中原来的 `compute_stats_single` 和 `process_single` 方法,你可以保持它们不变,Data-Juicer 会调用默认的batch化处理版本,它们会自动拆分单个样本以调用单样本版本的两个方法来支持batch化处理。你也可以自行实现更高效的batch化的版本。 ```python # ... (import some other libraries) OP_NAME = 'image_diffusion_mapper' @@ -138,7 +140,7 @@ class StatsKeys(object): **kwargs): super().__init__(*args, **kwargs) - def process(self, samples): + def process_batched(self, samples): # ... 
(some codes) ``` @@ -156,7 +158,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample): + def process_single(self, sample): # ... (some codes) # captions[index] is the prompt for diffusion model related_parameters = self.add_parameters( @@ -179,7 +181,7 @@ class StatsKeys(object): super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - def process(self, sample): + def process_single(self, sample): # ... (some codes) split_video_path = transfer_filename( original_video_path, OP_NAME, **self._init_parameters) @@ -373,7 +375,7 @@ class PerplexityFilter(Filter): AUTOINSTALL.check(['sentencepiece', 'kenlm']) # ... (some codes) - def process(self, sample): + def process_single(self, sample): # ... (some codes) ``` From 23f1a4bc75f424876369998ceb037317f922cadf Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 18 Oct 2024 11:08:25 +0800 Subject: [PATCH 11/13] * DO NOT allow to override the compute_stats or process methods in the subclass of Mapper and Filter --- data_juicer/ops/base_op.py | 20 +++++++++++++++++++ .../deduplicator/ray_basic_deduplicator.py | 4 ++-- data_juicer/ops/op_fusion.py | 4 ++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index ef4307a8c..6eecab75f 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -240,6 +240,16 @@ def __init__(self, *args, **kwargs): else: self.process = catch_map_single_exception(self.process_single) + # set the process method is not allowed to be overridden + def __init_subclass__(cls, **kwargs): + not_allowed_list = ['process'] + for method_name in not_allowed_list: + if method_name in cls.__dict__: + raise TypeError( + f'Method {method_name} cannot be overridden by subclass ' + f'{cls.__name__}. Please implement {method_name}_single ' + f'or {method_name}_batched.') + def process_batched(self, samples, *args, **kwargs): keys = samples.keys() first_key = next(iter(keys)) @@ -304,6 +314,16 @@ def __init__(self, *args, **kwargs): self.compute_stats_single) self.process = catch_map_single_exception(self.process_single) + # set the process method is not allowed to be overridden + def __init_subclass__(cls, **kwargs): + not_allowed_list = ['compute_stats', 'process'] + for method_name in not_allowed_list: + if method_name in cls.__dict__: + raise TypeError( + f'Method {method_name} cannot be overridden by subclass ' + f'{cls.__name__}. 
Please implement {method_name}_single ' + f'or {method_name}_batched.') + def compute_stats_batched(self, samples, *args, **kwargs): keys = samples.keys() num_samples = len(samples[Fields.stats]) diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index f8c40525e..876d65a94 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -50,7 +50,7 @@ def calculate_hash(self, sample, context=False): """Calculate hash value for the sample.""" raise NotImplementedError - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # init redis client r = redis.StrictRedis(host=self.redis_host, port=self.redis_port, db=0) # compute hash @@ -59,5 +59,5 @@ def compute_stats(self, sample, context=False): sample[HashKeys.is_duplicate] = r.setnx(md5_value, 1) return sample - def process(self, sample): + def process_single(self, sample): return sample[HashKeys.is_duplicate] diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py index dfc37d45d..26aaa556e 100644 --- a/data_juicer/ops/op_fusion.py +++ b/data_juicer/ops/op_fusion.py @@ -144,7 +144,7 @@ def __init__(self, fused_filters: List): if 'cuda' in accelerator_methods: self.accelerator = 'cuda' - def compute_stats(self, sample, rank=None): + def compute_stats_single(self, sample, rank=None): import av # context for the intermediate vars @@ -165,7 +165,7 @@ def compute_stats(self, sample, rank=None): _ = sample.pop(Fields.context) return sample - def process(self, sample): + def process_single(self, sample): # Only return True when all filters return True for op in self.fused_filters: if not op.process(sample): From 5ce818f1c6aec9452039de655e6c6a7c4f75d4e8 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 18 Oct 2024 11:09:34 +0800 Subject: [PATCH 12/13] * rename the methods for the newly-added OP image_face_count_filter --- data_juicer/ops/filter/image_face_count_filter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_juicer/ops/filter/image_face_count_filter.py b/data_juicer/ops/filter/image_face_count_filter.py index 8b14521f2..be34a28ad 100644 --- a/data_juicer/ops/filter/image_face_count_filter.py +++ b/data_juicer/ops/filter/image_face_count_filter.py @@ -75,7 +75,7 @@ def __init__(self, self.model_key = prepare_model(model_type='opencv_classifier', model_path=cv_classifier) - def compute_stats(self, sample, context=False): + def compute_stats_single(self, sample, context=False): # check if it's computed already if StatsKeys.face_ratios in sample[Fields.stats]: return sample @@ -109,7 +109,7 @@ def compute_stats(self, sample, context=False): ] return sample - def process(self, sample): + def process_single(self, sample): face_counts = sample[Fields.stats][StatsKeys.face_counts] if len(face_counts) <= 0: return True From 15e7e4a86a9bdfc21f9cc3e9ae8ca4f31b6f55ed Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 18 Oct 2024 15:15:42 +0800 Subject: [PATCH 13/13] * bug fixed: unaligned number of extracted frames from extract_video_frames_uniformly --- data_juicer/utils/mm_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 5b3ec0430..fad4d9740 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -546,13 +546,18 @@ def extract_video_frames_uniformly( container.seek(0) search_idx = 0 curr_pts = 
second_group[search_idx] / time_base + find_all = False for frame in container.decode(input_video_stream): if frame.pts >= curr_pts: extracted_frames.append(frame) search_idx += 1 if search_idx >= len(second_group): + find_all = True break curr_pts = second_group[search_idx] / time_base + if not find_all and frame is not None: + # add the last frame + extracted_frames.append(frame) else: # search from a key frame container.seek(int(key_frame_second * 1e6))
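To make the effect of this last fix concrete, here is a small, self-contained sketch of the failure mode. It is an illustration only, not the patched `extract_video_frames_uniformly`: decoded frames are reduced to plain timestamps, and the helper name `pick_frames_uniformly` is made up for the example.

```python
# Hypothetical, simplified stand-in (plain floats instead of av frames) for
# the scan-and-collect loop that the hunk above patches.
def pick_frames_uniformly(frame_pts, target_pts):
    """Pick the first decoded frame reaching each target timestamp.

    Without the final padding step, any target lying beyond the last decoded
    frame is silently dropped, so callers get fewer frames than requested
    (the "unaligned number of extracted frames" this commit fixes).
    """
    picked = []
    search_idx = 0
    last_frame = None
    for pts in frame_pts:  # decoded frames in presentation order
        last_frame = pts
        if search_idx < len(target_pts) and pts >= target_pts[search_idx]:
            picked.append(pts)
            search_idx += 1
    # pad with the last decoded frame so len(picked) == len(target_pts)
    while last_frame is not None and len(picked) < len(target_pts):
        picked.append(last_frame)
    return picked


# The last target (10.0s) lies beyond the final decoded frame (9.6s):
print(pick_frames_uniformly([0.0, 3.2, 6.4, 9.6], [0.0, 5.0, 10.0]))
# -> [0.0, 6.4, 9.6]; without padding, only 2 of the 3 requested frames return
```

The patch itself handles this inside the `av` decoding loop by tracking `find_all` and appending the last decoded frame once when the scan ends before every target second has been matched.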
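Finally, the convention this series converges on can be summarised in one self-contained sketch: subclasses implement `compute_stats_single`/`process_single` (or the `*_batched` variants), the base class provides default batched wrappers that fall back to the single-sample methods, and `__init_subclass__` rejects old-style `compute_stats`/`process` overrides. The classes and keys below (`MiniFilter`, `TextLengthMini`, the bare `'stats'` string) are simplified stand-ins for illustration, not the real `data_juicer` API.

```python
# Simplified stand-ins, not the real data_juicer base classes: the default
# *_batched methods fan out to the *_single implementations, and
# __init_subclass__ blocks legacy `compute_stats`/`process` overrides.
class MiniFilter:

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for name in ('compute_stats', 'process'):
            if name in cls.__dict__:
                raise TypeError(
                    f'{cls.__name__} must implement {name}_single or '
                    f'{name}_batched instead of overriding {name}.')

    def compute_stats_single(self, sample):
        return sample

    def process_single(self, sample):
        return True

    def compute_stats_batched(self, samples):
        # `samples` is a column-wise dict, e.g. {'text': [...], 'stats': [...]}
        keys = samples.keys()
        for i in range(len(samples['stats'])):
            one_sample = {key: samples[key][i] for key in keys}
            samples['stats'][i] = self.compute_stats_single(one_sample)['stats']
        return samples

    def process_batched(self, samples):
        return map(lambda stat: self.process_single({'stats': stat}),
                   samples['stats'])


class TextLengthMini(MiniFilter):
    """Keeps samples whose text is at least 3 characters long."""

    def compute_stats_single(self, sample):
        sample['stats']['len'] = len(sample['text'])
        return sample

    def process_single(self, sample):
        return sample['stats']['len'] >= 3


op = TextLengthMini()
batch = {'text': ['ab', 'abcd'], 'stats': [{}, {}]}
print(list(op.process_batched(op.compute_stats_batched(batch))))  # [False, True]
```

Defining `def process(self, sample)` directly on `TextLengthMini` would fail at class-creation time with the `TypeError` above, mirroring the guard that patch 11/13 adds to `Mapper` and `Filter` in `base_op.py`.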