diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 3e1e3ad29..4e0fcbe6e 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -233,7 +233,9 @@ def map(self, *args, **kargs): called_func.__self__, 'is_batched_op') and called_func.__self__.is_batched_op(): kargs['batched'] = True - kargs['batch_size'] = kargs.pop('batch_size', 1) + kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr( + called_func.__self__, 'is_batched_op' + ) and called_func.__self__.is_batched_op() else 1 else: kargs['batched'] = False @@ -266,12 +268,25 @@ def filter(self, *args, **kargs): args[0] = lambda x: nested_obj_factory(x) else: args[0] = wrap_func_with_nested_access(args[0]) + called_func = args[0] else: if 'function' not in kargs or kargs['function'] is None: kargs['function'] = lambda x: nested_obj_factory(x) else: kargs['function'] = wrap_func_with_nested_access( kargs['function']) + called_func = kargs['function'] + + # For wrapped function, try to get its unwrapped (bound) method + while not inspect.ismethod(called_func) and hasattr( + called_func, '__wrapped__'): + called_func = called_func.__wrapped__ + + # Batched is always required for fault tolerance + if inspect.ismethod( + called_func) and called_func.__self__.is_batched_op(): + kargs['batched'] = True + kargs['batch_size'] = kargs.pop('batch_size', 1) if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None: new_fingerprint = generate_fingerprint(self, *args, **kargs) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 17235a2b8..ce964a3af 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -108,14 +108,16 @@ def _run_single_op(self, op): self.num_proc, op.use_cuda()) num_gpus = get_num_gpus(op, op_proc) try: + batch_size = getattr(op, 'batch_size', + 1) if op.is_batched_op() else 1 if isinstance(op, Mapper): self.data = self.data.map_batches(op.process, - batch_size=1, + batch_size=batch_size, batch_format='pyarrow', num_gpus=num_gpus) elif isinstance(op, Filter): self.data = self.data.map_batches(op.compute_stats, - batch_size=1, + batch_size=batch_size, batch_format='pyarrow', num_gpus=num_gpus) if op.stats_export_path is not None: diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 43433704e..b866b7f87 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -133,6 +133,7 @@ def __init__(self, *args, **kwargs): self.image_key = kwargs.get('image_key', 'images') self.audio_key = kwargs.get('audio_key', 'audios') self.video_key = kwargs.get('video_key', 'videos') + self.batch_size = kwargs.get('batch_size', 1) # whether the model can be accelerated using cuda _accelerator = kwargs.get('accelerator', None) @@ -241,6 +242,7 @@ def run(self, dataset, *, exporter=None, tracer=None): self.process, num_proc=self.runtime_np(), with_rank=self.use_cuda(), + batch_size=self.batch_size, desc=self._name + '_process', ) if tracer: @@ -304,15 +306,18 @@ def run(self, dataset, *, exporter=None, tracer=None): 'initial_value': {} }, num_proc=self.runtime_np(), + batch_size=self.batch_size, desc='Adding new column for stats') dataset = dataset.map(self.compute_stats, num_proc=self.runtime_np(), with_rank=self.use_cuda(), + batch_size=self.batch_size, desc=self._name + '_compute_stats') if self.stats_export_path is not None: exporter.export_compute_stats(dataset, self.stats_export_path) new_dataset = dataset.filter(self.process, num_proc=self.runtime_np(), + batch_size=self.batch_size, desc=self._name + 
'_process') if tracer: tracer.trace_filter(self._name, dataset, new_dataset) diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py index e1cf90927..17361b29c 100644 --- a/data_juicer/ops/filter/alphanumeric_filter.py +++ b/data_juicer/ops/filter/alphanumeric_filter.py @@ -18,6 +18,8 @@ class AlphanumericFilter(Filter): """Filter to keep samples with alphabet/numeric ratio within a specific range.""" + _batched_op = True + def __init__(self, tokenization: bool = False, min_ratio: float = 0.25, @@ -52,36 +54,46 @@ def __init__(self, pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped', return_model=False) - def compute_stats(self, sample): - if self.tokenization: - if StatsKeys.alpha_token_ratio in sample[Fields.stats]: - return sample - alpha_count = sum( - map(lambda char: 1 - if char.isalpha() else 0, sample[self.text_key])) - tokenizer = get_model(self.model_key) - token_count = len( - get_words_from_document( - sample[self.text_key], - token_func=tokenizer.tokenize if tokenizer else None)) - sample[Fields.stats][StatsKeys.alpha_token_ratio] = ( - alpha_count / token_count) if token_count != 0 else 0.0 - else: - if StatsKeys.alnum_ratio in sample[Fields.stats]: - return sample - alnum_count = sum( - map(lambda char: 1 - if char.isalnum() else 0, sample[self.text_key])) - sample[Fields.stats][StatsKeys.alnum_ratio] = ( - alnum_count / len(sample[self.text_key])) if len( - sample[self.text_key]) != 0 else 0.0 - return sample + def compute_stats(self, samples): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] + + for idx, stat in enumerate(samples_stats): + cur_text = samples_list[idx] + if self.tokenization: + if StatsKeys.alpha_token_ratio in stat: + continue + alpha_count = sum( + map(lambda char: 1 if char.isalpha() else 0, cur_text)) + tokenizer = get_model(self.model_key) + token_count = len( + get_words_from_document( + cur_text, + token_func=tokenizer.tokenize if tokenizer else None)) + samples_stats[idx][StatsKeys.alpha_token_ratio] = ( + alpha_count / token_count) if token_count != 0 else 0.0 + else: + if StatsKeys.alnum_ratio in stat: + continue + alnum_count = sum( + map(lambda char: 1 if char.isalnum() else 0, cur_text)) + samples_stats[idx][StatsKeys.alnum_ratio] = ( + alnum_count / len(cur_text)) if len(cur_text) != 0 else 0.0 + + return samples - def process(self, sample): - ratio = sample[Fields.stats][ - StatsKeys.alpha_token_ratio] if self.tokenization else sample[ - Fields.stats][StatsKeys.alnum_ratio] - if self.min_ratio <= ratio <= self.max_ratio: - return True + def process(self, samples): + ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \ + else StatsKeys.alnum_ratio + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_ratio <= stat[ratio_key] <= self. 
+ max_ratio, samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_ratio <= samples[ + Fields.stats][ratio_key] <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py index 079a6d9f3..74d624a82 100644 --- a/data_juicer/ops/filter/average_line_length_filter.py +++ b/data_juicer/ops/filter/average_line_length_filter.py @@ -5,13 +5,17 @@ from ..base_op import OPERATORS, Filter from ..op_fusion import INTER_LINES +OP_NAME = 'average_line_length_filter' -@OPERATORS.register_module('average_line_length_filter') -@INTER_LINES.register_module('average_line_length_filter') + +@OPERATORS.register_module(OP_NAME) +@INTER_LINES.register_module(OP_NAME) class AverageLineLengthFilter(Filter): """Filter to keep samples with average line length within a specific range.""" + _batched_op = True + def __init__(self, min_len: int = 10, max_len: int = sys.maxsize, @@ -33,26 +37,38 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.avg_line_length in sample[Fields.stats]: - return sample - + def compute_stats(self, samples, context=False): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] context_key = f'{InterVars.lines}' - if context and context_key in sample[Fields.context]: - lines = sample[Fields.context][context_key] - else: - lines = sample[self.text_key].splitlines() - if context: - sample[Fields.context][context_key] = lines - sample[Fields.stats][StatsKeys.avg_line_length] = \ - len(sample[self.text_key]) / len(lines) \ - if len(lines) != 0 else 0.0 - return sample - - def process(self, sample): - if self.min_len <= sample[Fields.stats][ - StatsKeys.avg_line_length] <= self.max_len: - return True + + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.avg_line_length in stat: + continue + + cur_text = samples_list[idx] + if context and context_key in samples[Fields.context][idx]: + lines = samples[Fields.context][idx][context_key] + else: + lines = cur_text.splitlines() + if context: + samples[Fields.context][idx][context_key] = lines + samples_stats[idx][StatsKeys.avg_line_length] = \ + len(cur_text) / len(lines) if len(lines) != 0 else 0.0 + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_len <= stat[StatsKeys.avg_line_length + ] <= self.max_len, + samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_len <= samples[Fields.stats][ + StatsKeys.avg_line_length] <= self.max_len: + return True + else: + return False diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py index 1fb6949ff..a0441334a 100644 --- a/data_juicer/ops/filter/character_repetition_filter.py +++ b/data_juicer/ops/filter/character_repetition_filter.py @@ -15,6 +15,8 @@ class CharacterRepetitionFilter(Filter): """Filter to keep samples with char-level n-gram repetition ratio within a specific range.""" + _batched_op = True + def __init__(self, rep_len: PositiveInt = 10, min_ratio: float = 0.0, @@ -39,40 +41,54 @@ def __init__(self, self.min_ratio = min_ratio self.max_ratio = max_ratio - def compute_stats(self, sample): - # check if it's computed already - if StatsKeys.char_rep_ratio in 
sample[Fields.stats]: - return sample + def compute_stats(self, samples): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] + + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.char_rep_ratio in stat: + continue + + cur_text = samples_list[idx] + char_ngrams = [ + cur_text[i:i + self.n] + for i in range(len(cur_text) - self.n + 1) + ] + freq_char_ngrams = {} + for char_ngram in char_ngrams: + freq_char_ngrams[char_ngram] = ( + freq_char_ngrams.get(char_ngram, 0) + 1) - char_ngrams = [ - sample[self.text_key][i:i + self.n] - for i in range(len(sample[self.text_key]) - self.n + 1) - ] - freq_char_ngrams = {} - for char_ngram in char_ngrams: - freq_char_ngrams[char_ngram] = ( - freq_char_ngrams.get(char_ngram, 0) + 1) + if len(freq_char_ngrams) == 0: + samples_stats[idx][StatsKeys.char_rep_ratio] = 0.0 + continue - if len(freq_char_ngrams) == 0: - sample[Fields.stats][StatsKeys.char_rep_ratio] = 0.0 - return sample + freq_char_ngrams = sorted(list(freq_char_ngrams.values()), + reverse=True) + num_no_rep_char_ngrams = len( + [el for el in freq_char_ngrams if el == 1]) + num_rep_char_ngrams = min( + int(np.sqrt(len(freq_char_ngrams))), + len(freq_char_ngrams) - num_no_rep_char_ngrams, + ) + samples_stats[idx][StatsKeys.char_rep_ratio] = ( + sum(freq_char_ngrams[:num_rep_char_ngrams]) / + sum(freq_char_ngrams)) if sum(freq_char_ngrams) != 0 else 0.0 - freq_char_ngrams = sorted(list(freq_char_ngrams.values()), - reverse=True) - num_no_rep_char_ngrams = len( - [el for el in freq_char_ngrams if el == 1]) - num_rep_char_ngrams = min( - int(np.sqrt(len(freq_char_ngrams))), - len(freq_char_ngrams) - num_no_rep_char_ngrams, - ) - sample[Fields.stats][StatsKeys.char_rep_ratio] = (sum( - freq_char_ngrams[:num_rep_char_ngrams]) / sum(freq_char_ngrams)) \ - if sum(freq_char_ngrams) != 0 else 0.0 - return sample + return samples - def process(self, sample): - if self.min_ratio <= sample[Fields.stats][StatsKeys.char_rep_ratio] \ - <= self.max_ratio: - return True + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_ratio <= stat[ + StatsKeys.char_rep_ratio] <= self.max_ratio, + samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_ratio <= samples[Fields.stats][ + StatsKeys.char_rep_ratio] <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py index 2f2a4513e..146cfb0a2 100644 --- a/data_juicer/ops/filter/maximum_line_length_filter.py +++ b/data_juicer/ops/filter/maximum_line_length_filter.py @@ -5,13 +5,17 @@ from ..base_op import OPERATORS, Filter from ..op_fusion import INTER_LINES +OP_NAME = 'maximum_line_length_filter' -@OPERATORS.register_module('maximum_line_length_filter') -@INTER_LINES.register_module('maximum_line_length_filter') + +@OPERATORS.register_module(OP_NAME) +@INTER_LINES.register_module(OP_NAME) class MaximumLineLengthFilter(Filter): """Filter to keep samples with maximum line length within a specific range.""" + _batched_op = True + def __init__(self, min_len: int = 10, max_len: int = sys.maxsize, @@ -33,26 +37,39 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.max_line_length in sample[Fields.stats]: - return sample - + def compute_stats(self, samples, context=False): + 
samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] context_key = f'{InterVars.lines}' - if context and context_key in sample[Fields.context]: - lines = sample[Fields.context][context_key] - else: - lines = sample[self.text_key].splitlines() - if context: - sample[Fields.context][context_key] = lines - line_lengths = list(map(len, lines)) - sample[Fields.stats][StatsKeys.max_line_length] = max( - line_lengths) if line_lengths else 0 - return sample - - def process(self, sample): - if self.min_len <= sample[Fields.stats][ - StatsKeys.max_line_length] <= self.max_len: - return True + + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.max_line_length in stat: + continue + + if context and context_key in samples[Fields.context][idx]: + lines = samples[Fields.context][idx][context_key] + else: + lines = samples_list[idx].splitlines() + if context: + samples[Fields.context][idx][context_key] = lines + line_lengths = list(map(len, lines)) + samples_stats[idx][StatsKeys.max_line_length] = max( + line_lengths) if line_lengths else 0 + + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_len <= stat[StatsKeys.max_line_length + ] <= self.max_len, + samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_len <= samples[Fields.stats][ + StatsKeys.max_line_length] <= self.max_len: + return True + else: + return False diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index 5d0b396f9..287d15a11 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -23,6 +23,8 @@ class PerplexityFilter(Filter): """Filter to keep samples with perplexity score less than a specific max value.""" + _batched_op = True + def __init__(self, lang: str = 'en', max_ppl: float = 1500, @@ -44,33 +46,42 @@ def __init__(self, lang=lang) self.kl_model_key = prepare_model(model_type='kenlm', lang=lang) - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.perplexity in sample[Fields.stats]: - return sample - - # tokenization + def compute_stats(self, samples, context=False): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.sp_model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] - else: - tokenizer = get_model(self.sp_model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - text = ' '.join(words) - # compute perplexity - logits, length = 0, 0 - kenlm_model = get_model(self.kl_model_key) - for line in text.splitlines(): - logits += kenlm_model.score(line) - length += (len(line.split()) + 1) - ppl = (10.0**(-logits / length)) if length != 0 else 0.0 - sample[Fields.stats][StatsKeys.perplexity] = round(ppl, 1) - return sample + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.perplexity in stat: + continue + # tokenization + if context and words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] + else: + tokenizer = get_model(self.sp_model_key) + words = get_words_from_document( + samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else 
None) + if context: + samples[Fields.context][idx][words_key] = words + text = ' '.join(words) + # compute perplexity + logits, length = 0, 0 + kenlm_model = get_model(self.kl_model_key) + for line in text.splitlines(): + logits += kenlm_model.score(line) + length += (len(line.split()) + 1) + ppl = (10.0**(-logits / length)) if length != 0 else 0.0 + samples_stats[idx][StatsKeys.perplexity] = round(ppl, 1) - def process(self, sample): - return sample[Fields.stats][StatsKeys.perplexity] <= self.max_ppl + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl, + samples[Fields.stats])) + else: + return samples[Fields.stats][StatsKeys.perplexity] <= self.max_ppl diff --git a/data_juicer/ops/filter/special_characters_filter.py b/data_juicer/ops/filter/special_characters_filter.py index dc9ef1ed6..0b56f390e 100644 --- a/data_juicer/ops/filter/special_characters_filter.py +++ b/data_juicer/ops/filter/special_characters_filter.py @@ -13,6 +13,8 @@ class SpecialCharactersFilter(Filter): """Filter to keep samples with special-char ratio within a specific range.""" + _batched_op = True + def __init__(self, min_ratio: float = 0.0, max_ratio: float = 0.25, @@ -34,23 +36,34 @@ def __init__(self, self.min_ratio = min_ratio self.max_ratio = max_ratio - def compute_stats(self, sample): - # check if it's computed already - if StatsKeys.special_char_ratio in sample[Fields.stats]: - return sample - - # get ratio of special characters - sample[Fields.stats][StatsKeys.special_char_ratio] = ( - len([c - for c in sample[self.text_key] if c in SPECIAL_CHARACTERS]) / - len(sample[self.text_key])) if len( - sample[self.text_key]) != 0 else 0.0 - return sample - - def process(self, sample): - if self.min_ratio <= \ - sample[Fields.stats][StatsKeys.special_char_ratio] \ - <= self.max_ratio: - return True + def compute_stats(self, samples): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] + + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.special_char_ratio in stat: + continue + cur_text = samples_list[idx] + # get ratio of special characters + samples_stats[idx][StatsKeys.special_char_ratio] = ( + len([c for c in cur_text if c in SPECIAL_CHARACTERS]) / + len(cur_text)) if len(cur_text) != 0 else 0.0 + + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_ratio <= stat[ + StatsKeys.special_char_ratio] <= self.max_ratio, + samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_ratio <= \ + samples[Fields.stats][StatsKeys.special_char_ratio] \ + <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py index 6fa966889..ec61f8304 100644 --- a/data_juicer/ops/filter/text_length_filter.py +++ b/data_juicer/ops/filter/text_length_filter.py @@ -10,6 +10,8 @@ class TextLengthFilter(Filter): """Filter to keep samples with total text length within a specific range.""" + _batched_op = True + def __init__(self, min_len: int = 10, max_len: int = sys.maxsize, @@ -31,17 +33,28 @@ def __init__(self, self.min_len = min_len self.max_len = max_len - def compute_stats(self, sample): - # check if it's computed already - if StatsKeys.text_len in sample[Fields.stats]: - return sample - - sample[Fields.stats][StatsKeys.text_len] = 
len(sample[self.text_key]) - return sample - - def process(self, sample): - if self.min_len <= sample[Fields.stats][ - StatsKeys.text_len] <= self.max_len: - return True + def compute_stats(self, samples): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] + for i, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.text_len in stat: + continue + else: + samples_stats[i][StatsKeys.text_len] = len(samples_list[i]) + + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_len <= stat[StatsKeys.text_len] <= + self.max_len, samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_len <= samples[Fields.stats][ + StatsKeys.text_len] <= self.max_len: + return True + else: + return False diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index 3009c55f9..71f806e25 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -25,6 +25,8 @@ class WordRepetitionFilter(Filter): """Filter to keep samples with word-level n-gram repetition ratio within a specific range.""" + _batched_op = True + def __init__(self, lang: str = 'en', tokenization: bool = False, @@ -59,57 +61,70 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.word_rep_ratio in sample[Fields.stats]: - return sample - - # try to get words from context + def compute_stats(self, samples, context=False): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - - # try to get refined words from context - refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \ - f'False-[2]-' - if context and refined_words_key in sample[Fields.context]: - words = sample[Fields.context][refined_words_key] - else: - words = words_refinement(words, - lower_case=True, - strip_chars=SPECIAL_CHARACTERS) - if context: - sample[Fields.context][refined_words_key] = words - word_ngrams = [ - ' '.join(words[i:i + self.n]) - for i in range(len(words) - self.n + 1) - ] - freq_word_ngrams = {} - for word_ngram in word_ngrams: - freq_word_ngrams[word_ngram] = ( - freq_word_ngrams.get(word_ngram, 0) + 1) - - if len(freq_word_ngrams) == 0: - sample[Fields.stats][StatsKeys.word_rep_ratio] = 0.0 - return sample - - freq_word_ngrams = list(freq_word_ngrams.values()) - rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] - sample[Fields.stats][StatsKeys.word_rep_ratio] = ( - sum(rep_more_than_one) / - sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 - return sample - - def process(self, sample): - if self.min_ratio <= sample[Fields.stats][StatsKeys.word_rep_ratio] \ - <= self.max_ratio: - return True + + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.word_rep_ratio in stat: + continue + # try to get words from context + if context and words_key in 
samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] + else: + tokenizer = get_model(self.model_key) + words = get_words_from_document( + samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + if context: + samples[Fields.context][idx][words_key] = words + + # try to get refined words from context + refined_words_key = f'{InterVars.refined_words}-' \ + f'True-SPECIAL_CHARS-False-[2]-' + if context and refined_words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][refined_words_key] + else: + words = words_refinement(words, + lower_case=True, + strip_chars=SPECIAL_CHARACTERS) + if context: + samples[Fields.context][idx][refined_words_key] = words + word_ngrams = [ + ' '.join(words[i:i + self.n]) + for i in range(len(words) - self.n + 1) + ] + freq_word_ngrams = {} + for word_ngram in word_ngrams: + freq_word_ngrams[word_ngram] = ( + freq_word_ngrams.get(word_ngram, 0) + 1) + + if len(freq_word_ngrams) == 0: + samples_stats[idx][StatsKeys.word_rep_ratio] = 0.0 + continue + + freq_word_ngrams = list(freq_word_ngrams.values()) + rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1] + samples_stats[idx][StatsKeys.word_rep_ratio] = ( + sum(rep_more_than_one) / + sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0 + + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_ratio <= stat[ + StatsKeys.word_rep_ratio] <= self.max_ratio, + samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_ratio <= samples[Fields.stats][ + StatsKeys.word_rep_ratio] <= self.max_ratio: + return True + else: + return False diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py index 7d354cb54..07eb8e2b7 100644 --- a/data_juicer/ops/filter/words_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -21,6 +21,8 @@ class WordsNumFilter(Filter): """Filter to keep samples with total words number within a specific range.""" + _batched_op = True + def __init__(self, lang: str = 'en', tokenization: bool = False, @@ -52,28 +54,40 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats(self, sample, context=False): - # check if it's computed already - if StatsKeys.num_words in sample[Fields.stats]: - return sample - + def compute_stats(self, samples, context=False): + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) - sample[Fields.stats][StatsKeys.num_words] = len(words) - return sample - def process(self, sample): - if self.min_num <= sample[Fields.stats][ - StatsKeys.num_words] <= self.max_num: - return True + for idx, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.num_words in stat: + continue + if context and words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] + else: + tokenizer = get_model(self.model_key) + words = get_words_from_document( + 
samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + if context: + samples[Fields.context][idx][words_key] = words + words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS) + samples_stats[idx][StatsKeys.num_words] = len(words) + + return samples + + def process(self, samples): + if isinstance(samples[Fields.stats], list): + return list( + map( + lambda stat: self.min_num <= stat[StatsKeys.num_words] <= + self.max_num, samples[Fields.stats])) else: - return False + # single sample for ray filter + if self.min_num <= samples[Fields.stats][ + StatsKeys.num_words] <= self.max_num: + return True + else: + return False diff --git a/data_juicer/ops/mapper/chinese_convert_mapper.py b/data_juicer/ops/mapper/chinese_convert_mapper.py index 818f1b1d4..9236ddaa2 100644 --- a/data_juicer/ops/mapper/chinese_convert_mapper.py +++ b/data_juicer/ops/mapper/chinese_convert_mapper.py @@ -27,6 +27,8 @@ class ChineseConvertMapper(Mapper): """Mapper to convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji.""" + _batched_op = True + def __init__(self, mode: str = 's2t', *args, **kwargs): """ Initialization method. @@ -82,8 +84,10 @@ def __init__(self, mode: str = 's2t', *args, **kwargs): self.mode = mode prepare_converter(self.mode) - def process(self, sample): + def process(self, samples): prepare_converter(self.mode) - sample[self.text_key] = OPENCC_CONVERTER.convert(sample[self.text_key]) - return sample + samples[self.text_key] = list( + map(lambda text: OPENCC_CONVERTER.convert(text), + samples[self.text_key])) + return samples diff --git a/data_juicer/ops/mapper/clean_copyright_mapper.py b/data_juicer/ops/mapper/clean_copyright_mapper.py index dabb0cd40..3bf6fcbdf 100644 --- a/data_juicer/ops/mapper/clean_copyright_mapper.py +++ b/data_juicer/ops/mapper/clean_copyright_mapper.py @@ -12,6 +12,8 @@ class CleanCopyrightMapper(Mapper): """Mapper to clean copyright comments at the beginning of the text samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. 
@@ -23,21 +25,19 @@ def __init__(self, *args, **kwargs): self.pat = re.compile('/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/') self.cpat = re.compile('copyright', re.IGNORECASE) - def process(self, sample): - - r = self.pat.search(sample[self.text_key]) + def _process_single_sample(self, sample): + r = self.pat.search(sample) if r: # found one, now see if it contains "copyright", if so strip it span = r.span() - sub = sample[self.text_key][span[0]:span[1]] + sub = sample[span[0]:span[1]] if self.cpat.search(sub): # cut it - sample[self.text_key] = sample[ - self.text_key][:span[0]] + sample[self.text_key][span[1]:] + sample = sample[:span[0]] + sample[span[1]:] return sample - lines = sample[self.text_key].split('\n') + lines = sample.split('\n') skip = 0 # Greedy replace any file that begins with comment block, most @@ -51,5 +51,11 @@ def process(self, sample): if skip: # we skipped, consume it - sample[self.text_key] = '\n'.join(lines[skip:]) + sample = '\n'.join(lines[skip:]) return sample + + def process(self, samples): + samples[self.text_key] = list( + map(lambda text: self._process_single_sample(text), + samples[self.text_key])) + return samples diff --git a/data_juicer/ops/mapper/clean_email_mapper.py b/data_juicer/ops/mapper/clean_email_mapper.py index b8d2a1cbb..90dcca60f 100644 --- a/data_juicer/ops/mapper/clean_email_mapper.py +++ b/data_juicer/ops/mapper/clean_email_mapper.py @@ -9,6 +9,8 @@ class CleanEmailMapper(Mapper): """Mapper to clean email in text samples.""" + _batched_op = True + def __init__(self, pattern: Optional[str] = None, repl: str = '', @@ -34,13 +36,13 @@ def __init__(self, self.repl = repl - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue + samples[self.text_key][idx] = re.sub(pattern=self.pattern, + repl=self.repl, + string=text, + flags=re.DOTALL) - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) - return sample + return samples diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py index 5c2c30c57..d959cc85f 100644 --- a/data_juicer/ops/mapper/clean_html_mapper.py +++ b/data_juicer/ops/mapper/clean_html_mapper.py @@ -16,6 +16,8 @@ class CleanHtmlMapper(Mapper): """Mapper to clean html code in text samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. @@ -25,7 +27,7 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - def process(self, sample): + def process(self, samples): def _clean_html(raw_html): raw_html = raw_html.replace('
  • ', '\n*') @@ -35,5 +37,6 @@ def _clean_html(raw_html): parser = HTMLParser(raw_html) return parser.text() - sample[self.text_key] = _clean_html(sample[self.text_key]) - return sample + samples[self.text_key] = list( + map(lambda text: _clean_html(text), samples[self.text_key])) + return samples diff --git a/data_juicer/ops/mapper/clean_ip_mapper.py b/data_juicer/ops/mapper/clean_ip_mapper.py index b36d13aae..709037ddd 100644 --- a/data_juicer/ops/mapper/clean_ip_mapper.py +++ b/data_juicer/ops/mapper/clean_ip_mapper.py @@ -9,6 +9,8 @@ class CleanIpMapper(Mapper): """Mapper to clean ipv4 and ipv6 address in text samples.""" + _batched_op = True + def __init__(self, pattern: Optional[str] = None, repl: str = '', @@ -38,13 +40,12 @@ def __init__(self, self.pattern = pattern[2:-1] self.repl = repl - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample - - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue + samples[self.text_key][idx] = re.sub(pattern=self.pattern, + repl=self.repl, + string=text, + flags=re.DOTALL) + return samples diff --git a/data_juicer/ops/mapper/clean_links_mapper.py b/data_juicer/ops/mapper/clean_links_mapper.py index ebeac8668..f08abc78f 100644 --- a/data_juicer/ops/mapper/clean_links_mapper.py +++ b/data_juicer/ops/mapper/clean_links_mapper.py @@ -12,6 +12,8 @@ class CleanLinksMapper(Mapper): """Mapper to clean links like http/https/ftp in text samples.""" + _batched_op = True + def __init__(self, pattern: Optional[str] = None, repl: str = '', @@ -44,13 +46,13 @@ def __init__(self, self.pattern = pattern[2:-1] self.repl = repl - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) - return sample + samples[self.text_key][idx] = re.sub(pattern=self.pattern, + repl=self.repl, + string=text, + flags=re.DOTALL) + return samples diff --git a/data_juicer/ops/mapper/expand_macro_mapper.py b/data_juicer/ops/mapper/expand_macro_mapper.py index 2f5d7fe83..b83455103 100644 --- a/data_juicer/ops/mapper/expand_macro_mapper.py +++ b/data_juicer/ops/mapper/expand_macro_mapper.py @@ -12,6 +12,8 @@ class ExpandMacroMapper(Mapper): """Mapper to expand macro definitions in the document body of Latex samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. 
@@ -55,26 +57,29 @@ def _build_non_arg_macros_dict(self, file_content): macros[macro_name] = macro_val return macros - def process(self, sample): - non_arg_macros = self._build_non_arg_macros_dict(sample[self.text_key]) - - # TODO: macros that take arguments are not supported yet - arg_macros = {} - - # inline-expand all non-arg macros - for macro_name, macro_value in non_arg_macros.items(): - sample[self.text_key] = re.sub( - # make pattern grouped to make sure that the macro is not part - # of a longer alphanumeric word - pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])', - # replace the macro with its value and add back the character - # that was matched after the macro - repl=macro_value + r'\2', - string=sample[self.text_key]) - - # inline-expand all macros that use args - # TODO: inline-expand macros with args - for macro_name, macro_value in arg_macros.items(): - pass - - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + non_arg_macros = self._build_non_arg_macros_dict(text) + + # TODO: macros that take arguments are not supported yet + arg_macros = {} + + # inline-expand all non-arg macros + for macro_name, macro_value in non_arg_macros.items(): + text = re.sub( + # make pattern grouped to make sure that the macro + # is not part of a longer alphanumeric word + pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])', + # replace the macro with its value and add back the + # character that was matched after the macro + repl=macro_value + r'\2', + string=text) + + # inline-expand all macros that use args + # TODO: inline-expand macros with args + for macro_name, macro_value in arg_macros.items(): + pass + + samples[self.text_key][idx] = text + + return samples diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py index a7d06da3a..4ca71c30a 100644 --- a/data_juicer/ops/mapper/fix_unicode_mapper.py +++ b/data_juicer/ops/mapper/fix_unicode_mapper.py @@ -12,6 +12,8 @@ class FixUnicodeMapper(Mapper): """Mapper to fix unicode errors in text samples.""" + _batched_op = True + def __init__(self, normalization: str = None, *args, **kwargs): """ Initialization method. @@ -33,7 +35,10 @@ def __init__(self, normalization: str = None, *args, **kwargs): 'supported. Can only be one of ' '["NFC", "NFKC", "NFD", "NFKD"]') - def process(self, sample): - sample[self.text_key] = ftfy.fix_text(sample[self.text_key], - normalization=self.normalization) - return sample + def process(self, samples): + samples[self.text_key] = list( + map( + lambda text: ftfy.fix_text(text, + normalization=self.normalization), + samples[self.text_key])) + return samples diff --git a/data_juicer/ops/mapper/punctuation_normalization_mapper.py b/data_juicer/ops/mapper/punctuation_normalization_mapper.py index b6640e9eb..6531833a3 100644 --- a/data_juicer/ops/mapper/punctuation_normalization_mapper.py +++ b/data_juicer/ops/mapper/punctuation_normalization_mapper.py @@ -10,6 +10,8 @@ class PunctuationNormalizationMapper(Mapper): """Mapper to normalize unicode punctuations to English punctuations in text samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. 
@@ -55,8 +57,10 @@ def __init__(self, *args, **kwargs): '►': '-', } - def process(self, sample): - sample[self.text_key] = ''.join([ - self.punctuation_unicode.get(c, c) for c in sample[self.text_key] - ]) - return sample + def process(self, samples): + samples[self.text_key] = list( + map( + lambda text: ''.join( + [self.punctuation_unicode.get(c, c) for c in text]), + samples[self.text_key])) + return samples diff --git a/data_juicer/ops/mapper/remove_bibliography_mapper.py b/data_juicer/ops/mapper/remove_bibliography_mapper.py index 2ce852d66..d2a2bf342 100644 --- a/data_juicer/ops/mapper/remove_bibliography_mapper.py +++ b/data_juicer/ops/mapper/remove_bibliography_mapper.py @@ -12,6 +12,8 @@ class RemoveBibliographyMapper(Mapper): """Mapper to remove bibliography at the end of documents in Latex samples.""" + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. @@ -27,9 +29,12 @@ def __init__(self, *args, **kwargs): self.pattern += r'\\bibliography\{.*\}' self.pattern += r').*$' - def process(self, sample): - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=r'', - string=sample[self.text_key], - flags=re.DOTALL) - return sample + def process(self, samples): + samples[self.text_key] = list( + map( + lambda text: re.sub(pattern=self.pattern, + repl=r'', + string=text, + flags=re.DOTALL), samples[self.text_key])) + + return samples diff --git a/data_juicer/ops/mapper/remove_comments_mapper.py b/data_juicer/ops/mapper/remove_comments_mapper.py index c5f083c14..09fe4e5ef 100644 --- a/data_juicer/ops/mapper/remove_comments_mapper.py +++ b/data_juicer/ops/mapper/remove_comments_mapper.py @@ -17,6 +17,8 @@ class RemoveCommentsMapper(Mapper): Only support 'tex' for now. """ + _batched_op = True + def __init__(self, doc_type: Union[str, List[str]] = 'tex', inline: bool = True, @@ -37,19 +39,23 @@ def __init__(self, self.inline = inline self.multiline = multiline - def process(self, sample): + def process(self, samples): # TODO: remove different comments by sample type - if self.inline: - # remove all in comments within a line - sample[self.text_key] = re.sub(pattern=r'[^\\]%.+$', - repl=r'', - string=sample[self.text_key], - flags=re.MULTILINE) - - if self.multiline: - sample[self.text_key] = re.sub(pattern=r'(?m)^%.*\n?', - repl=r'', - string=sample[self.text_key], - flags=re.MULTILINE) - return sample + for idx, text in enumerate(samples[self.text_key]): + if self.inline: + # remove all in comments within a line + text = re.sub(pattern=r'[^\\]%.+$', + repl=r'', + string=text, + flags=re.MULTILINE) + + if self.multiline: + text = re.sub(pattern=r'(?m)^%.*\n?', + repl=r'', + string=text, + flags=re.MULTILINE) + + samples[self.text_key][idx] = text + + return samples diff --git a/data_juicer/ops/mapper/remove_header_mapper.py b/data_juicer/ops/mapper/remove_header_mapper.py index 8371d2f99..bb967e929 100644 --- a/data_juicer/ops/mapper/remove_header_mapper.py +++ b/data_juicer/ops/mapper/remove_header_mapper.py @@ -12,6 +12,8 @@ class RemoveHeaderMapper(Mapper): """Mapper to remove headers at the beginning of documents in Latex samples.""" + _batched_op = True + def __init__(self, drop_no_head: bool = True, *args, **kwargs): """ Initialization method. 
@@ -34,15 +36,17 @@ def __init__(self, drop_no_head: bool = True, *args, **kwargs): self.drop_no_head = drop_no_head - def process(self, sample): + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + if self.drop_no_head: + text = '' + continue + text = re.sub(pattern=self.pattern, + repl=r'\2', + string=text, + flags=re.DOTALL) - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - if self.drop_no_head: - sample[self.text_key] = '' - return sample + samples[self.text_key][idx] = text - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=r'\2', - string=sample[self.text_key], - flags=re.DOTALL) - return sample + return samples diff --git a/data_juicer/ops/mapper/remove_long_words_mapper.py b/data_juicer/ops/mapper/remove_long_words_mapper.py index ff8fa2d29..5aea47516 100644 --- a/data_juicer/ops/mapper/remove_long_words_mapper.py +++ b/data_juicer/ops/mapper/remove_long_words_mapper.py @@ -13,6 +13,8 @@ class RemoveLongWordsMapper(Mapper): """Mapper to remove long words within a specific range.""" + _batched_op = True + def __init__(self, min_len: int = 1, max_len: int = sys.maxsize, @@ -41,11 +43,13 @@ def should_keep_long_word(self, word): else: return False - def process(self, sample): - - sentences = split_on_newline_tab_whitespace(sample[self.text_key]) - sentences = [[[ - word for word in subsentence if self.should_keep_long_word(word) - ] for subsentence in sentence] for sentence in sentences] - sample[self.text_key] = merge_on_whitespace_tab_newline(sentences) - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + sentences = split_on_newline_tab_whitespace(text) + sentences = [[[ + word for word in subsentence + if self.should_keep_long_word(word) + ] for subsentence in sentence] for sentence in sentences] + samples[self.text_key][idx] = merge_on_whitespace_tab_newline( + sentences) + return samples diff --git a/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py b/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py index 3e6cd494d..371697efc 100644 --- a/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py +++ b/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py @@ -7,6 +7,8 @@ class RemoveNonChineseCharacterlMapper(Mapper): """Mapper to remove non chinese Character in text samples.""" + _batched_op = True + def __init__(self, keep_alphabet: bool = True, keep_number: bool = True, @@ -33,13 +35,13 @@ def __init__(self, else: self.pattern += u']' - def process(self, sample): - - if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL): - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + if not re.search(self.pattern, text, flags=re.DOTALL): + continue - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=r'', - string=sample[self.text_key], - flags=re.DOTALL) - return sample + samples[self.text_key][idx] = re.sub(pattern=self.pattern, + repl=r'', + string=text, + flags=re.DOTALL) + return samples diff --git a/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py b/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py index a1069d24d..add0a719e 100644 --- a/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py +++ b/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py @@ -15,6 +15,8 @@ def split_sentence(text): class RemoveRepeatSentencesMapper(Mapper): """Mapper to remove repeat sentences in text 
samples.""" + _batched_op = True + def __init__(self, lowercase: bool = False, ignore_special_character: bool = True, @@ -43,28 +45,29 @@ def __init__(self, self.remove_regex = re.compile(r'[^a-zA-Z0-9\u4e00-\u9fa5\n\t ]' ) if ignore_special_character else None - def process(self, sample): + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + lines = [e for e in text.split('\n')] + new_lines = [] + hash_set = set([]) + for line in lines: + new_sent = '' + if line: + sentences = split_sentence(line) + for sentence in sentences: + copy = sentence.strip() + if self.lowercase: + copy = copy.lower() + if self.remove_regex: + copy = self.remove_regex.sub('', copy) - lines = [e for e in sample[self.text_key].split('\n')] - new_lines = [] - hash_set = set([]) - for line in lines: - new_sent = '' - if line: - sentences = split_sentence(line) - for sentence in sentences: - copy = sentence.strip() - if self.lowercase: - copy = copy.lower() - if self.remove_regex: - copy = self.remove_regex.sub('', copy) + if len(copy) < self.min_repeat_sentence_length: + new_sent += sentence + elif copy not in hash_set: + new_sent += sentence + hash_set.add(copy) + new_lines.append(new_sent) - if len(copy) < self.min_repeat_sentence_length: - new_sent += sentence - elif copy not in hash_set: - new_sent += sentence - hash_set.add(copy) - new_lines.append(new_sent) + samples[self.text_key][idx] = '\n'.join(new_lines) - sample[self.text_key] = '\n'.join(new_lines) - return sample + return samples diff --git a/data_juicer/ops/mapper/remove_specific_chars_mapper.py b/data_juicer/ops/mapper/remove_specific_chars_mapper.py index 99e15afef..d487efa2f 100644 --- a/data_juicer/ops/mapper/remove_specific_chars_mapper.py +++ b/data_juicer/ops/mapper/remove_specific_chars_mapper.py @@ -9,6 +9,8 @@ class RemoveSpecificCharsMapper(Mapper): """Mapper to clean specific chars in text samples.""" + _batched_op = True + def __init__(self, chars_to_remove: Union[str, List[str]] = '◆●■►▼▲▴∆▻▷❖♡□', *args, @@ -28,13 +30,14 @@ def __init__(self, else: self.pattern = None - def process(self, sample): - + def process(self, samples): if self.pattern is None: - return sample - - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=r'', - string=sample[self.text_key], - flags=re.DOTALL) - return sample + return samples + + samples[self.text_key] = list( + map( + lambda text: re.sub(pattern=self.pattern, + repl=r'', + string=text, + flags=re.DOTALL), samples[self.text_key])) + return samples diff --git a/data_juicer/ops/mapper/remove_table_text_mapper.py b/data_juicer/ops/mapper/remove_table_text_mapper.py index ca12104c0..8273c8dab 100644 --- a/data_juicer/ops/mapper/remove_table_text_mapper.py +++ b/data_juicer/ops/mapper/remove_table_text_mapper.py @@ -14,6 +14,8 @@ class RemoveTableTextMapper(Mapper): number of tables. 
""" + _batched_op = True + def __init__(self, min_col: Annotated[int, Field(ge=2, le=20)] = 2, max_col: Annotated[int, Field(ge=2, le=20)] = 20, @@ -32,12 +34,12 @@ def __init__(self, self.max_col = max_col self.pattern = r'(?<=\n)((\S+?)([ |\t](\S+?)){%d}\n+){2,}' - def process(self, sample): + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + for idx in range(self.min_col - 1, self.max_col): + pattern = re.compile(self.pattern % idx) + text = pattern.sub('', text) - text = sample[self.text_key] - for i in range(self.min_col - 1, self.max_col): - pattern = re.compile(self.pattern % i) - text = pattern.sub('', text) + samples[self.text_key][idx] = text - sample[self.text_key] = text - return sample + return samples diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py index d262c1d17..97bb61319 100644 --- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py +++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py @@ -18,6 +18,8 @@ class RemoveWordsWithIncorrectSubstringsMapper(Mapper): """Mapper to remove words with incorrect substrings.""" + _batched_op = True + def __init__(self, lang: str = 'en', tokenization: bool = False, @@ -48,25 +50,30 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings): should_keep = all([(i_substr not in word) for i_substr in substrings]) return should_keep - def process(self, sample): - if self.tokenization: - tokenizer = get_model(self.model_key) - sentences = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - words = [ - word.replace('▁', '') for word in sentences - if self.should_keep_word_with_incorrect_substrings( - word.replace('▁', ''), self.substrings) - ] - if len(words) != len(sentences): - sample[self.text_key] = ''.join(words) - else: - sentences = split_on_newline_tab_whitespace(sample[self.text_key]) - sentences = [[[ - word for word in subsentence - if self.should_keep_word_with_incorrect_substrings( - word, self.substrings) - ] for subsentence in sentence] for sentence in sentences] - sample[self.text_key] = merge_on_whitespace_tab_newline(sentences) - return sample + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + if self.tokenization: + tokenizer = get_model(self.model_key) + sentences = get_words_from_document( + text, + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + words = [ + word.replace('▁', '') for word in sentences + if self.should_keep_word_with_incorrect_substrings( + word.replace('▁', ''), self.substrings) + ] + if len(words) != len(sentences): + text = ''.join(words) + else: + sentences = split_on_newline_tab_whitespace(text) + sentences = [[[ + word for word in subsentence + if self.should_keep_word_with_incorrect_substrings( + word, self.substrings) + ] for subsentence in sentence] for sentence in sentences] + text = merge_on_whitespace_tab_newline(sentences) + + samples[self.text_key][idx] = text + + return samples diff --git a/data_juicer/ops/mapper/replace_content_mapper.py b/data_juicer/ops/mapper/replace_content_mapper.py index d16e4ec7c..324cc6357 100644 --- a/data_juicer/ops/mapper/replace_content_mapper.py +++ b/data_juicer/ops/mapper/replace_content_mapper.py @@ -11,6 +11,8 @@ class ReplaceContentMapper(Mapper): a specific regular expression pattern with a designated replacement string.""" + _batched_op = 
True + def __init__(self, pattern: Union[str, List[str], None] = None, repl: Union[str, List[str]] = '', @@ -42,21 +44,23 @@ def _prepare_pattern(self, pattern: str) -> re.Pattern: pattern = pattern[2:-1] return re.compile(pattern, flags=re.DOTALL) - def process(self, sample): + def process(self, samples): if self.pattern is None: - return sample - - for i, pattern in enumerate(self.compiled_patterns): - if isinstance(self.repl, list) and i < len(self.repl): - replacement = self.repl[i] - elif isinstance(self.repl, list) and i >= len(self.repl): - raise ValueError(f"pattern length: {len(self.pattern)} '" - f'must be equal to ' - f'repl length: {len(self.repl)}') - else: - replacement = self.repl - - sample[self.text_key] = pattern.sub(replacement, - sample[self.text_key]) - - return sample + return samples + + for idx, text in enumerate(samples[self.text_key]): + for i, pattern in enumerate(self.compiled_patterns): + if isinstance(self.repl, list) and i < len(self.repl): + replacement = self.repl[i] + elif isinstance(self.repl, list) and i >= len(self.repl): + raise ValueError(f"pattern length: {len(self.pattern)} '" + f'must be equal to ' + f'repl length: {len(self.repl)}') + else: + replacement = self.repl + + text = pattern.sub(replacement, text) + + samples[self.text_key][idx] = text + + return samples diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py index 522c01300..819cfd55c 100644 --- a/data_juicer/ops/mapper/sentence_split_mapper.py +++ b/data_juicer/ops/mapper/sentence_split_mapper.py @@ -14,6 +14,8 @@ class SentenceSplitMapper(Mapper): """Mapper to split text samples to sentences.""" + _batched_op = True + def __init__(self, lang: str = 'en', *args, **kwargs): """ Initialization method. @@ -26,10 +28,14 @@ def __init__(self, lang: str = 'en', *args, **kwargs): self.lang = lang self.model_key = prepare_model(model_type='nltk', lang=lang) - def process(self, sample): + def process(self, samples): nltk_model = get_model(self.model_key) - sample[self.text_key] = get_sentences_from_document( - sample[self.text_key], - model_func=nltk_model.tokenize if nltk_model else None) - return sample + + samples[self.text_key] = [ + get_sentences_from_document( + text, model_func=nltk_model.tokenize if nltk_model else None) + for text in samples[self.text_key] + ] + + return samples diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py index 6fa44b559..3102cedab 100644 --- a/data_juicer/ops/mapper/whitespace_normalization_mapper.py +++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py @@ -16,6 +16,8 @@ class WhitespaceNormalizationMapper(Mapper): https://en.wikipedia.org/wiki/Whitespace_character """ + _batched_op = True + def __init__(self, *args, **kwargs): """ Initialization method. 
@@ -25,13 +27,15 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - def process(self, sample): - # remove whitespaces before and after the main content - text = sample[self.text_key].strip() + def process(self, samples): + for idx, text in enumerate(samples[self.text_key]): + # remove whitespaces before and after the main content + text = text.strip() - # replace all kinds of whitespaces with ' ' - sample[self.text_key] = ''.join([ - char if char not in VARIOUS_WHITESPACES else ' ' for char in text - ]) + # replace all kinds of whitespaces with ' ' + samples[self.text_key][idx] = ''.join([ + char if char not in VARIOUS_WHITESPACES else ' ' + for char in text + ]) - return sample + return samples diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index 54d9a44dc..e7e568f48 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -50,6 +50,7 @@ def test_yaml_cfg_file(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }, 'nested dict load fail, for nonparametric op') self.assertDictEqual( @@ -67,6 +68,7 @@ def test_yaml_cfg_file(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }, 'nested dict load fail, un-expected internal value') @@ -132,6 +134,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }) self.assertDictEqual( @@ -149,6 +152,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }) self.assertDictEqual( @@ -166,6 +170,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }) self.assertDictEqual( @@ -183,6 +188,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }) self.assertDictEqual( @@ -200,6 +206,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, + 'batch_size': 1, } }) diff --git a/tests/ops/filter/test_alphanumeric_filter.py b/tests/ops/filter/test_alphanumeric_filter.py index d4ea828c0..3d66189d0 100644 --- a/tests/ops/filter/test_alphanumeric_filter.py +++ b/tests/ops/filter/test_alphanumeric_filter.py @@ -1,9 +1,6 @@ import unittest -from data_juicer.core.data import NestedDataset as Dataset - from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter -from data_juicer.utils.constant import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG @@ -39,7 +36,7 @@ def test_case(self): 'text': 'emoji表情测试下😊,😸31231\n' }] dataset = self.generate_dataset(ds_list) - op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9) + op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9, batch_size=3, num_proc=1) result = self.run_single_op(dataset, op, ["text"]) self.assertDatasetEqual(result, tgt_list) @@ -67,7 +64,7 @@ def test_token_case(self): 'text': 'Do you need a cup of coffee?' 
}] dataset = self.generate_dataset(ds_list) - op = AlphanumericFilter(tokenization=True, min_ratio=1.5) + op = AlphanumericFilter(tokenization=True, min_ratio=1.5, batch_size=2, num_proc=1) result = self.run_single_op(dataset, op, ["text"]) self.assertDatasetEqual(result, tgt_list) diff --git a/tests/ops/filter/test_average_line_length_filter.py b/tests/ops/filter/test_average_line_length_filter.py index e294cb77e..e9768d605 100644 --- a/tests/ops/filter/test_average_line_length_filter.py +++ b/tests/ops/filter/test_average_line_length_filter.py @@ -4,50 +4,77 @@ from data_juicer.ops.filter.average_line_length_filter import \ AverageLineLengthFilter -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, InterVars from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class AverageLineLengthFilterTest(DataJuicerTestCaseBase): + text_key = 'text' + ds_list = [{ + text_key: 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + text_key: + "Today is Sund Sund Sunda and it's a happy day!\nYou know" + }, { + text_key: 'a v s e e f g a qkc' + }, { + text_key: ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + text_key: 'Do you need a cup of coffee?' + }, { + text_key: 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + text_key: 'a v s e e f g a qkc' + }, { + text_key: 'emoji表情测试下😊,😸31231\n' + }] def _run_average_line_length_filter(self, dataset: Dataset, target_list, - op): + op, context=False): if Fields.stats not in dataset.features: # TODO: # this is a temp solution, # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) - dataset = dataset.select_columns(column_names=['text']) - res_list = dataset.to_list() + if context: + dataset = dataset.add_column(name=Fields.context, + column=[{}] * dataset.num_rows) + dataset = dataset.map( + op.compute_stats, + batch_size=op.batch_size, + fn_kwargs={'context': context} + ) + dataset = dataset.filter(op.process, batch_size=op.batch_size) + dataset_test = dataset.select_columns(column_names=[self.text_key]) + res_list = dataset_test.to_list() self.assertEqual(res_list, target_list) - def test_case(self): + return dataset - ds_list = [{ - 'text': 'a=1\nb\nc=1+2+3+5\nd=6' - }, { - 'text': - "Today is Sund Sund Sunda and it's a happy day!\nYou know" - }, { - 'text': 'a v s e e f g a qkc' - }, { - 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' - }, { - 'text': 'Do you need a cup of coffee?' 
- }, { - 'text': 'emoji表情测试下😊,😸31231\n' - }] - tgt_list = [{ - 'text': 'a v s e e f g a qkc' - }, { - 'text': 'emoji表情测试下😊,😸31231\n' - }] - dataset = Dataset.from_list(ds_list) - op = AverageLineLengthFilter(min_len=10, max_len=20) - self._run_average_line_length_filter(dataset, tgt_list, op) + def test_case_default(self): + dataset = Dataset.from_list(self.ds_list) + op = AverageLineLengthFilter(min_len=10, max_len=20, batch_size=3) + self._run_average_line_length_filter(dataset, self.tgt_list, op, context=False) + + def test_case_context(self): + dataset = Dataset.from_list(self.ds_list) + op = AverageLineLengthFilter(min_len=10, max_len=20, batch_size=2) + dataset = self._run_average_line_length_filter(dataset, self.tgt_list, op, context=True) + + dataset = dataset.select_columns(column_names=[Fields.context]) + res_list = dataset.to_list() + + tgt_context_list = [ + { + Fields.context: { + InterVars.lines: tgt[self.text_key].splitlines() + } + } for tgt in self.tgt_list + ] + + self.assertEqual(res_list, tgt_context_list) if __name__ == '__main__': diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py index 77c1ac1d2..8c4334efd 100644 --- a/tests/ops/filter/test_character_repetition_filter.py +++ b/tests/ops/filter/test_character_repetition_filter.py @@ -18,8 +18,8 @@ def _run_character_repetition_filter(self, dataset: Dataset, target_list, # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) + dataset = dataset.map(op.compute_stats, batch_size=op.batch_size, num_proc=1) + dataset = dataset.filter(op.process, batch_size=op.batch_size, num_proc=2) dataset = dataset.select_columns(column_names=['text']) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -42,7 +42,11 @@ def test_case(self): 'text': '中文也是一个字算一个长度' }] dataset = Dataset.from_list(ds_list) - op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=0.4) + op = CharacterRepetitionFilter( + rep_len=5, + min_ratio=0.0, + max_ratio=0.4, + batch_size=2) self._run_character_repetition_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_maximum_line_length_filter.py b/tests/ops/filter/test_maximum_line_length_filter.py index 6f1cab7f6..6596c0f34 100644 --- a/tests/ops/filter/test_maximum_line_length_filter.py +++ b/tests/ops/filter/test_maximum_line_length_filter.py @@ -4,50 +4,77 @@ from data_juicer.ops.filter.maximum_line_length_filter import \ MaximumLineLengthFilter -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, InterVars from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class MaximumLineLengthFilterTest(DataJuicerTestCaseBase): + text_key = 'text' + ds_list = [{ + text_key: 'a=1\nb\nc=1+2+3+5\nd=6' + }, { + text_key: + "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" + }, { + text_key: 'a v s e e f g a qkc' + }, { + text_key: ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' + }, { + text_key: 'Do you need a cup of coffee?' 
+ }, { + text_key: 'emoji表情测试下😊,😸31231\n' + }] + tgt_list = [{ + text_key: 'a v s e e f g a qkc' + }, { + text_key: 'emoji表情测试下😊,😸31231\n' + }] def _run_maximum_line_length_filter(self, dataset: Dataset, target_list, - op): + op, context=False): if Fields.stats not in dataset.features: # TODO: # this is a temp solution, # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) - dataset = dataset.select_columns(column_names=['text']) - res_list = dataset.to_list() + if context: + dataset = dataset.add_column(name=Fields.context, + column=[{}] * dataset.num_rows) + dataset = dataset.map( + op.compute_stats, + batch_size=op.batch_size, + fn_kwargs={'context': context} + ) + dataset = dataset.filter(op.process, batch_size=op.batch_size) + dataset_test = dataset.select_columns(column_names=[self.text_key]) + res_list = dataset_test.to_list() self.assertEqual(res_list, target_list) - def test_case(self): + return dataset - ds_list = [{ - 'text': 'a=1\nb\nc=1+2+3+5\nd=6' - }, { - 'text': - "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know" - }, { - 'text': 'a v s e e f g a qkc' - }, { - 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►' - }, { - 'text': 'Do you need a cup of coffee?' - }, { - 'text': 'emoji表情测试下😊,😸31231\n' - }] - tgt_list = [{ - 'text': 'a v s e e f g a qkc' - }, { - 'text': 'emoji表情测试下😊,😸31231\n' - }] - dataset = Dataset.from_list(ds_list) - op = MaximumLineLengthFilter(min_len=10, max_len=20) - self._run_maximum_line_length_filter(dataset, tgt_list, op) + def test_case_default(self): + dataset = Dataset.from_list(self.ds_list) + op = MaximumLineLengthFilter(min_len=10, max_len=20, batch_size=3) + self._run_maximum_line_length_filter(dataset, self.tgt_list, op, context=False) + + def test_case_context(self): + dataset = Dataset.from_list(self.ds_list) + op = MaximumLineLengthFilter(min_len=10, max_len=20, batch_size=2) + dataset = self._run_maximum_line_length_filter(dataset, self.tgt_list, op, context=True) + + dataset = dataset.select_columns(column_names=[Fields.context]) + res_list = dataset.to_list() + + tgt_context_list = [ + { + Fields.context: { + InterVars.lines: tgt[self.text_key].splitlines() + } + } for tgt in self.tgt_list + ] + + self.assertEqual(res_list, tgt_context_list) if __name__ == '__main__': diff --git a/tests/ops/filter/test_perplexity_filter.py b/tests/ops/filter/test_perplexity_filter.py index 07e87d17c..e3e582ff6 100644 --- a/tests/ops/filter/test_perplexity_filter.py +++ b/tests/ops/filter/test_perplexity_filter.py @@ -9,20 +9,28 @@ class PerplexityFilterTest(DataJuicerTestCaseBase): - def _run_perplexity_filter(self, dataset: Dataset, target_list, op): + def _run_perplexity_filter(self, dataset: Dataset, target_list, op, context=False): if Fields.stats not in dataset.features: # TODO: # this is a temp solution, # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) - dataset = dataset.select_columns(column_names=['text']) - res_list = dataset.to_list() + if context: + dataset = dataset.add_column(name=Fields.context, + column=[{}] * dataset.num_rows) + dataset = dataset.map( + op.compute_stats, + batch_size=op.batch_size, + fn_kwargs={'context': context} + ) + dataset = dataset.filter(op.process, batch_size=op.batch_size) + dataset_test = 
dataset.select_columns(column_names=['text']) + res_list = dataset_test.to_list() self.assertEqual(res_list, target_list) + return dataset - def test_en_case(self): + def _test_en_case(self, context=False): ds_list = [{ 'text': "Today is Sunday and it's a happy day!" @@ -44,8 +52,24 @@ def test_en_case(self): 'text': 'Do you need a cup of coffee?' }] dataset = Dataset.from_list(ds_list) - op = PerplexityFilter(lang='en', max_ppl=900) - self._run_perplexity_filter(dataset, tgt_list, op) + op = PerplexityFilter(lang='en', max_ppl=900, batch_size=2) + dataset= self._run_perplexity_filter(dataset, tgt_list, op, context) + if context: + dataset = dataset.select_columns(column_names=[Fields.context]) + context_list = dataset.to_list() + res_words_list = [list(context_list[i][Fields.context].values()) \ + for i in range(len(context_list))] + tgt_words_list = [ + [['▁Today', '▁is', '▁Sunday', '▁and', '▁it', "'", 's', '▁a', '▁happy', '▁day', '!']], + [['▁Do', '▁you', '▁need', '▁a', '▁cup', '▁of', '▁coffee', '?']] + ] + self.assertListEqual(res_words_list, tgt_words_list) + + def test_en_case_default(self): + self._test_en_case(context=False) + + def test_en_case_context(self): + self._test_en_case(context=True) if __name__ == '__main__': diff --git a/tests/ops/filter/test_special_characters_filter.py b/tests/ops/filter/test_special_characters_filter.py index b1dd8632e..b2e8292f1 100644 --- a/tests/ops/filter/test_special_characters_filter.py +++ b/tests/ops/filter/test_special_characters_filter.py @@ -18,8 +18,8 @@ def _run_special_characters_filter(self, dataset: Dataset, target_list, # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) + dataset = dataset.map(op.compute_stats, batch_size=op.batch_size) + dataset = dataset.filter(op.process, batch_size=op.batch_size) dataset = dataset.select_columns(column_names=['text']) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -49,7 +49,7 @@ def test_case(self): 'text': 'Do you need a cup of coffee?' 
}] dataset = Dataset.from_list(ds_list) - op = SpecialCharactersFilter(min_ratio=0.0, max_ratio=0.25) + op = SpecialCharactersFilter(min_ratio=0.0, max_ratio=0.25, batch_size=2) self._run_special_characters_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_text_length_filter.py b/tests/ops/filter/test_text_length_filter.py index 67efb6c60..2a653ff42 100644 --- a/tests/ops/filter/test_text_length_filter.py +++ b/tests/ops/filter/test_text_length_filter.py @@ -16,8 +16,8 @@ def _run_text_length_filter(self, dataset: Dataset, target_list, op): # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) + dataset = dataset.map(op.compute_stats, batch_size=3) + dataset = dataset.filter(op.process, batch_size=2) dataset = dataset.select_columns(column_names=['text']) res_list = dataset.to_list() self.assertEqual(res_list, target_list) diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py index 0d53a164d..f099bec05 100644 --- a/tests/ops/filter/test_word_num_filter.py +++ b/tests/ops/filter/test_word_num_filter.py @@ -16,8 +16,8 @@ def _run_words_num_filter(self, dataset: Dataset, target_list, op): # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) + dataset = dataset.map(op.compute_stats, batch_size=op.batch_size) + dataset = dataset.filter(op.process, batch_size=op.batch_size) dataset = dataset.select_columns(column_names=['text']) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -41,7 +41,7 @@ def test_case(self): 'text': 'a v s e c s f e f g a a a ' }] dataset = Dataset.from_list(ds_list) - op = WordsNumFilter(min_num=5, max_num=15) + op = WordsNumFilter(min_num=5, max_num=15, batch_size=2) self._run_words_num_filter(dataset, tgt_list, op) def test_zh_case(self): @@ -68,7 +68,8 @@ def test_zh_case(self): op = WordsNumFilter(lang='zh', tokenization=True, min_num=10, - max_num=25) + max_num=25, + batch_size=1) self._run_words_num_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_word_repetition_filter.py b/tests/ops/filter/test_word_repetition_filter.py index f59576ef8..030189d56 100644 --- a/tests/ops/filter/test_word_repetition_filter.py +++ b/tests/ops/filter/test_word_repetition_filter.py @@ -16,8 +16,8 @@ def _run_word_repetition_filter(self, dataset: Dataset, target_list, op): # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) + dataset = dataset.map(op.compute_stats, batch_size=op.batch_size) + dataset = dataset.filter(op.process, batch_size=op.batch_size) dataset = dataset.select_columns(column_names=['text']) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -51,7 +51,11 @@ def test_en_case(self): 'This proposed a novel proposed pretraining proposed pretraining.' 
}] dataset = Dataset.from_list(ds_list) - op = WordRepetitionFilter(rep_len=3, min_ratio=0.0, max_ratio=0.2) + op = WordRepetitionFilter( + rep_len=3, + min_ratio=0.0, + max_ratio=0.2, + batch_size=2) self._run_word_repetition_filter(dataset, tgt_list, op) def test_zh_case(self): @@ -79,7 +83,8 @@ def test_zh_case(self): tokenization=True, rep_len=3, min_ratio=0.0, - max_ratio=0.2) + max_ratio=0.2, + batch_size=1) self._run_word_repetition_filter(dataset, tgt_list, op) diff --git a/tests/ops/mapper/test_chinese_convert_mapper.py b/tests/ops/mapper/test_chinese_convert_mapper.py index 9bbe8e8df..bc21f40fe 100644 --- a/tests/ops/mapper/test_chinese_convert_mapper.py +++ b/tests/ops/mapper/test_chinese_convert_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.chinese_convert_mapper import ChineseConvertMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self, mode='s2t'): self.op = ChineseConvertMapper(mode) def _run_chinese_convert(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_s2t(self): diff --git a/tests/ops/mapper/test_clean_copyright_mapper.py b/tests/ops/mapper/test_clean_copyright_mapper.py index 726d829f7..a236988f7 100644 --- a/tests/ops/mapper/test_clean_copyright_mapper.py +++ b/tests/ops/mapper/test_clean_copyright_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_copyright_mapper import CleanCopyrightMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = CleanCopyrightMapper() def _run_clean_copyright(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_clean_copyright(self): diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py index b3f0e5e9a..1ff7e389e 100644 --- a/tests/ops/mapper/test_clean_email_mapper.py +++ b/tests/ops/mapper/test_clean_email_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class CleanEmailMapperTest(DataJuicerTestCaseBase): def _run_clean_email(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_clean_email(self): diff --git a/tests/ops/mapper/test_clean_html_mapper.py b/tests/ops/mapper/test_clean_html_mapper.py index 69249b60a..71d4e11ee 100644 --- a/tests/ops/mapper/test_clean_html_mapper.py +++ b/tests/ops/mapper/test_clean_html_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from 
data_juicer.ops.mapper.clean_html_mapper import CleanHtmlMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = CleanHtmlMapper() def _run_helper(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_complete_html_text(self): diff --git a/tests/ops/mapper/test_clean_ip_mapper.py b/tests/ops/mapper/test_clean_ip_mapper.py index ccbaf52b7..479228263 100644 --- a/tests/ops/mapper/test_clean_ip_mapper.py +++ b/tests/ops/mapper/test_clean_ip_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_ip_mapper import CleanIpMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class CleanIpMapperTest(DataJuicerTestCaseBase): def _run_clean_ip(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_ipv4(self): diff --git a/tests/ops/mapper/test_clean_links_mapper.py b/tests/ops/mapper/test_clean_links_mapper.py index 28e14b2d9..5efcd4acd 100644 --- a/tests/ops/mapper/test_clean_links_mapper.py +++ b/tests/ops/mapper/test_clean_links_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.clean_links_mapper import CleanLinksMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = CleanLinksMapper() def _run_clean_links(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_lower_ftp_links_text(self): diff --git a/tests/ops/mapper/test_expand_macro_mapper.py b/tests/ops/mapper/test_expand_macro_mapper.py index 68dbf047b..bdc758193 100644 --- a/tests/ops/mapper/test_expand_macro_mapper.py +++ b/tests/ops/mapper/test_expand_macro_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.expand_macro_mapper import ExpandMacroMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = ExpandMacroMapper() def _run_expand_macro(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self): diff --git a/tests/ops/mapper/test_fix_unicode_mapper.py b/tests/ops/mapper/test_fix_unicode_mapper.py index 547020b51..1e969d117 100644 --- a/tests/ops/mapper/test_fix_unicode_mapper.py +++ b/tests/ops/mapper/test_fix_unicode_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.fix_unicode_mapper import FixUnicodeMapper 
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = FixUnicodeMapper() def _run_fix_unicode(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_bad_unicode_text(self): diff --git a/tests/ops/mapper/test_punctuation_normalization_mapper.py b/tests/ops/mapper/test_punctuation_normalization_mapper.py index a69d4040e..080666ce8 100644 --- a/tests/ops/mapper/test_punctuation_normalization_mapper.py +++ b/tests/ops/mapper/test_punctuation_normalization_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.punctuation_normalization_mapper import \ PunctuationNormalizationMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = PunctuationNormalizationMapper() def _run_punctuation_normalization(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self): diff --git a/tests/ops/mapper/test_remove_bibliography_mapper.py b/tests/ops/mapper/test_remove_bibliography_mapper.py index 76096fe93..9d08c2a4d 100644 --- a/tests/ops/mapper/test_remove_bibliography_mapper.py +++ b/tests/ops/mapper/test_remove_bibliography_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_bibliography_mapper import \ RemoveBibliographyMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = RemoveBibliographyMapper() def _run_remove_bibliography(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_bibliography_case(self): diff --git a/tests/ops/mapper/test_remove_comments_mapper.py b/tests/ops/mapper/test_remove_comments_mapper.py index 81a0df5de..93a287460 100644 --- a/tests/ops/mapper/test_remove_comments_mapper.py +++ b/tests/ops/mapper/test_remove_comments_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_comments_mapper import RemoveCommentsMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class RemoveCommentsMapperTest(DataJuicerTestCaseBase): def _run_remove_comments(self, samples, op): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_tex_case(self): diff --git a/tests/ops/mapper/test_remove_header_mapper.py b/tests/ops/mapper/test_remove_header_mapper.py index c91bfe790..0196b0317 100644 --- a/tests/ops/mapper/test_remove_header_mapper.py +++ 
b/tests/ops/mapper/test_remove_header_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_header_mapper import RemoveHeaderMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ def setUp(self): self.op = RemoveHeaderMapper() def _run_remove_header(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self): diff --git a/tests/ops/mapper/test_remove_long_words_mapper.py b/tests/ops/mapper/test_remove_long_words_mapper.py index 533d7a717..817043979 100644 --- a/tests/ops/mapper/test_remove_long_words_mapper.py +++ b/tests/ops/mapper/test_remove_long_words_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_long_words_mapper import \ RemoveLongWordsMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -8,9 +9,11 @@ class RemoveLongWordsMapperTest(DataJuicerTestCaseBase): def _run_remove_long_words(self, samples, op): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_normal_case(self): diff --git a/tests/ops/mapper/test_remove_non_chinese_character_mapper.py b/tests/ops/mapper/test_remove_non_chinese_character_mapper.py index 283a75ab0..1ab000107 100644 --- a/tests/ops/mapper/test_remove_non_chinese_character_mapper.py +++ b/tests/ops/mapper/test_remove_non_chinese_character_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_non_chinese_character_mapper import \ RemoveNonChineseCharacterlMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -12,9 +13,11 @@ def setUp(self, keep_alphabet=True, keep_number=True, keep_punc=True): keep_punc) def _run_remove_non_chinese_character(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_remove_non_chinese_character(self): diff --git a/tests/ops/mapper/test_remove_repeat_sentences_mapper.py b/tests/ops/mapper/test_remove_repeat_sentences_mapper.py index a7fe347fe..a2a560f97 100644 --- a/tests/ops/mapper/test_remove_repeat_sentences_mapper.py +++ b/tests/ops/mapper/test_remove_repeat_sentences_mapper.py @@ -2,6 +2,7 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_repeat_sentences_mapper import \ RemoveRepeatSentencesMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,9 +11,11 @@ class RemoveRepeatSentencesMapperTest(DataJuicerTestCaseBase): def _run_helper(self, samples, op): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + 
self.assertEqual(data['text'], data['target']) def test_text(self): diff --git a/tests/ops/mapper/test_remove_specific_chars_mapper.py b/tests/ops/mapper/test_remove_specific_chars_mapper.py index f61a3f6fc..f786db3ee 100644 --- a/tests/ops/mapper/test_remove_specific_chars_mapper.py +++ b/tests/ops/mapper/test_remove_specific_chars_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_specific_chars_mapper import \ RemoveSpecificCharsMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = RemoveSpecificCharsMapper() def _run_helper(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_complete_html_text(self): diff --git a/tests/ops/mapper/test_remove_table_text_mapper.py b/tests/ops/mapper/test_remove_table_text_mapper.py index 2be4a2453..214a36e79 100644 --- a/tests/ops/mapper/test_remove_table_text_mapper.py +++ b/tests/ops/mapper/test_remove_table_text_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_table_text_mapper import \ RemoveTableTextMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = RemoveTableTextMapper() def _run_remove_header(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_single_table_case(self): diff --git a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py index 02157ad52..f6d6d109f 100644 --- a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py +++ b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.remove_words_with_incorrect_substrings_mapper import \ RemoveWordsWithIncorrectSubstringsMapper # noqa: E501 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -8,9 +9,11 @@ class RemoveWordsWithIncorrectSubstringsMapperTest(DataJuicerTestCaseBase): def _run_remove_words_with_incorrect_sbstrings(self, samples, op): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_en_case(self): diff --git a/tests/ops/mapper/test_replace_content_mapper.py b/tests/ops/mapper/test_replace_content_mapper.py index 64f88c888..23ddb3453 100644 --- a/tests/ops/mapper/test_replace_content_mapper.py +++ b/tests/ops/mapper/test_replace_content_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.replace_content_mapper import ReplaceContentMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ 
class ReplaceContentMapperTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_special_char_pattern_text(self): diff --git a/tests/ops/mapper/test_sentence_split_mapper.py b/tests/ops/mapper/test_sentence_split_mapper.py index 3cdf3a977..4352fdce8 100644 --- a/tests/ops/mapper/test_sentence_split_mapper.py +++ b/tests/ops/mapper/test_sentence_split_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.sentence_split_mapper import SentenceSplitMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -7,9 +8,11 @@ class SentenceSplitMapperTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): - for sample in samples: - result = op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_en_text(self): diff --git a/tests/ops/mapper/test_whitespace_normalization_mapper.py b/tests/ops/mapper/test_whitespace_normalization_mapper.py index 985cc7076..c92516ff7 100644 --- a/tests/ops/mapper/test_whitespace_normalization_mapper.py +++ b/tests/ops/mapper/test_whitespace_normalization_mapper.py @@ -1,5 +1,6 @@ import unittest +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.whitespace_normalization_mapper import \ WhitespaceNormalizationMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -11,9 +12,11 @@ def setUp(self): self.op = WhitespaceNormalizationMapper() def _run_whitespace_normalization(self, samples): - for sample in samples: - result = self.op.process(sample) - self.assertEqual(result['text'], result['target']) + dataset = Dataset.from_list(samples) + dataset = dataset.map(self.op.process, batch_size=2) + + for data in dataset: + self.assertEqual(data['text'], data['target']) def test_case(self):
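Illustrative sketch only, not part of the diff above: the changes converge on one batched-op convention, where an op sets `_batched_op = True`, its `process` (or `compute_stats`) receives a column-oriented dict of lists, and callers pass the op's `batch_size` through to `Dataset.map`, as the updated tests do. The `UppercaseMapper` below is a hypothetical op invented purely to show that convention; the imports, the `batch_size` kwarg, and the `map(..., batch_size=...)` call mirror the diff, while the assumption that `text_key` defaults to `'text'` follows the existing base-op behavior.

# Hypothetical example, assuming the batched-op convention introduced above.
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.base_op import Mapper


class UppercaseMapper(Mapper):
    """Invented mapper used only to illustrate the batched interface."""

    _batched_op = True  # opt in to batched execution, like the ops in this PR

    def process(self, samples):
        # `samples` maps each column name to a list of values for the batch
        samples[self.text_key] = [
            text.upper() for text in samples[self.text_key]
        ]
        return samples


# Usage sketch mirroring the updated mapper tests: NestedDataset.map detects
# the batched op and sets `batched=True` itself, so only the op's batch_size
# needs to be forwarded.
ds = Dataset.from_list([{'text': 'a'}, {'text': 'b'}, {'text': 'c'}])
op = UppercaseMapper(batch_size=2)  # stored as `self.batch_size` by the base op
ds = ds.map(op.process, batch_size=op.batch_size)
print(ds.to_list())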