diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index 3e1e3ad29..4e0fcbe6e 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -233,7 +233,9 @@ def map(self, *args, **kargs):
called_func.__self__,
'is_batched_op') and called_func.__self__.is_batched_op():
kargs['batched'] = True
- kargs['batch_size'] = kargs.pop('batch_size', 1)
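+        # batched mode may be enabled even for non-batched OPs (e.g. for
+        # fault tolerance); only genuine batched OPs use a custom batch_size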
+ kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr(
+ called_func.__self__, 'is_batched_op'
+ ) and called_func.__self__.is_batched_op() else 1
else:
kargs['batched'] = False
@@ -266,12 +268,25 @@ def filter(self, *args, **kargs):
args[0] = lambda x: nested_obj_factory(x)
else:
args[0] = wrap_func_with_nested_access(args[0])
+ called_func = args[0]
else:
if 'function' not in kargs or kargs['function'] is None:
kargs['function'] = lambda x: nested_obj_factory(x)
else:
kargs['function'] = wrap_func_with_nested_access(
kargs['function'])
+ called_func = kargs['function']
+
+ # For wrapped function, try to get its unwrapped (bound) method
+ while not inspect.ismethod(called_func) and hasattr(
+ called_func, '__wrapped__'):
+ called_func = called_func.__wrapped__
+
+ # Batched is always required for fault tolerance
+ if inspect.ismethod(
+ called_func) and called_func.__self__.is_batched_op():
+ kargs['batched'] = True
+ kargs['batch_size'] = kargs.pop('batch_size', 1)
if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
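
A minimal sketch of the new dispatch, mirroring the unit tests below (dataset
contents and parameter values are illustrative):

    from data_juicer.core.data import NestedDataset
    from data_juicer.ops.filter.text_length_filter import TextLengthFilter
    from data_juicer.utils.constant import Fields

    ds = NestedDataset.from_list([{'text': 'hello world'}, {'text': 'hi'}])
    ds = ds.add_column(name=Fields.stats, column=[{}] * ds.num_rows)
    op = TextLengthFilter(min_len=5, batch_size=2)
    # passing the bound method of a batched OP now auto-enables batched mode
    # and forwards batch_size to the underlying HuggingFace map/filter
    ds = ds.map(op.compute_stats, batch_size=op.batch_size)
    ds = ds.filter(op.process, batch_size=op.batch_size)
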
diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py
index 17235a2b8..ce964a3af 100644
--- a/data_juicer/core/ray_data.py
+++ b/data_juicer/core/ray_data.py
@@ -108,14 +108,16 @@ def _run_single_op(self, op):
self.num_proc, op.use_cuda())
num_gpus = get_num_gpus(op, op_proc)
try:
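+            # batched OPs consume their configured batch_size; non-batched
+            # OPs keep processing a single sample per batch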
+ batch_size = getattr(op, 'batch_size',
+ 1) if op.is_batched_op() else 1
if isinstance(op, Mapper):
self.data = self.data.map_batches(op.process,
- batch_size=1,
+ batch_size=batch_size,
batch_format='pyarrow',
num_gpus=num_gpus)
elif isinstance(op, Filter):
self.data = self.data.map_batches(op.compute_stats,
- batch_size=1,
+ batch_size=batch_size,
batch_format='pyarrow',
num_gpus=num_gpus)
if op.stats_export_path is not None:
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 43433704e..b866b7f87 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -133,6 +133,7 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get('image_key', 'images')
self.audio_key = kwargs.get('audio_key', 'audios')
self.video_key = kwargs.get('video_key', 'videos')
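+        # number of samples processed per call when the OP runs batched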
+ self.batch_size = kwargs.get('batch_size', 1)
# whether the model can be accelerated using cuda
_accelerator = kwargs.get('accelerator', None)
@@ -241,6 +242,7 @@ def run(self, dataset, *, exporter=None, tracer=None):
self.process,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
+ batch_size=self.batch_size,
desc=self._name + '_process',
)
if tracer:
@@ -304,15 +306,18 @@ def run(self, dataset, *, exporter=None, tracer=None):
'initial_value': {}
},
num_proc=self.runtime_np(),
+ batch_size=self.batch_size,
desc='Adding new column for stats')
dataset = dataset.map(self.compute_stats,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
+ batch_size=self.batch_size,
desc=self._name + '_compute_stats')
if self.stats_export_path is not None:
exporter.export_compute_stats(dataset, self.stats_export_path)
new_dataset = dataset.filter(self.process,
num_proc=self.runtime_np(),
+ batch_size=self.batch_size,
desc=self._name + '_process')
if tracer:
tracer.trace_filter(self._name, dataset, new_dataset)
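
With batch_size now a first-class OP kwarg, run() forwards it to every
map/filter call it issues; a minimal sketch, reusing the dataset from the
sketch above (values illustrative):

    op = TextLengthFilter(min_len=10, batch_size=512)
    assert op.batch_size == 512
    ds = op.run(ds)  # stats column, compute_stats and process all run batched
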
diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py
index e1cf90927..17361b29c 100644
--- a/data_juicer/ops/filter/alphanumeric_filter.py
+++ b/data_juicer/ops/filter/alphanumeric_filter.py
@@ -18,6 +18,8 @@ class AlphanumericFilter(Filter):
"""Filter to keep samples with alphabet/numeric ratio within a specific
range."""
+ _batched_op = True
+
def __init__(self,
tokenization: bool = False,
min_ratio: float = 0.25,
@@ -52,36 +54,46 @@ def __init__(self,
pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped',
return_model=False)
- def compute_stats(self, sample):
- if self.tokenization:
- if StatsKeys.alpha_token_ratio in sample[Fields.stats]:
- return sample
- alpha_count = sum(
- map(lambda char: 1
- if char.isalpha() else 0, sample[self.text_key]))
- tokenizer = get_model(self.model_key)
- token_count = len(
- get_words_from_document(
- sample[self.text_key],
- token_func=tokenizer.tokenize if tokenizer else None))
- sample[Fields.stats][StatsKeys.alpha_token_ratio] = (
- alpha_count / token_count) if token_count != 0 else 0.0
- else:
- if StatsKeys.alnum_ratio in sample[Fields.stats]:
- return sample
- alnum_count = sum(
- map(lambda char: 1
- if char.isalnum() else 0, sample[self.text_key]))
- sample[Fields.stats][StatsKeys.alnum_ratio] = (
- alnum_count / len(sample[self.text_key])) if len(
- sample[self.text_key]) != 0 else 0.0
- return sample
+ def compute_stats(self, samples):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
+
+ for idx, stat in enumerate(samples_stats):
+ cur_text = samples_list[idx]
+ if self.tokenization:
+ if StatsKeys.alpha_token_ratio in stat:
+ continue
+ alpha_count = sum(
+ map(lambda char: 1 if char.isalpha() else 0, cur_text))
+ tokenizer = get_model(self.model_key)
+ token_count = len(
+ get_words_from_document(
+ cur_text,
+ token_func=tokenizer.tokenize if tokenizer else None))
+ samples_stats[idx][StatsKeys.alpha_token_ratio] = (
+ alpha_count / token_count) if token_count != 0 else 0.0
+ else:
+ if StatsKeys.alnum_ratio in stat:
+ continue
+ alnum_count = sum(
+ map(lambda char: 1 if char.isalnum() else 0, cur_text))
+ samples_stats[idx][StatsKeys.alnum_ratio] = (
+ alnum_count / len(cur_text)) if len(cur_text) != 0 else 0.0
+
+ return samples
- def process(self, sample):
- ratio = sample[Fields.stats][
- StatsKeys.alpha_token_ratio] if self.tokenization else sample[
- Fields.stats][StatsKeys.alnum_ratio]
- if self.min_ratio <= ratio <= self.max_ratio:
- return True
+ def process(self, samples):
+ ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \
+ else StatsKeys.alnum_ratio
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_ratio <= stat[ratio_key] <= self.
+ max_ratio, samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_ratio <= samples[
+ Fields.stats][ratio_key] <= self.max_ratio:
+ return True
+ else:
+ return False
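
Batched OPs receive column-oriented batches (one list per field) rather than
single-sample dicts; a minimal sketch of the contract, with illustrative
values:

    from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter
    from data_juicer.utils.constant import Fields

    samples = {
        'text': ['a1 b2 c3', '!!!!'],
        Fields.stats: [{}, {}],
    }
    op = AlphanumericFilter(min_ratio=0.25, max_ratio=0.9)
    samples = op.compute_stats(samples)  # fills one stats dict per sample
    keep = op.process(samples)  # -> [True, False]
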
diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py
index 079a6d9f3..74d624a82 100644
--- a/data_juicer/ops/filter/average_line_length_filter.py
+++ b/data_juicer/ops/filter/average_line_length_filter.py
@@ -5,13 +5,17 @@
from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_LINES
+OP_NAME = 'average_line_length_filter'
-@OPERATORS.register_module('average_line_length_filter')
-@INTER_LINES.register_module('average_line_length_filter')
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_LINES.register_module(OP_NAME)
class AverageLineLengthFilter(Filter):
"""Filter to keep samples with average line length within a specific
range."""
+ _batched_op = True
+
def __init__(self,
min_len: int = 10,
max_len: int = sys.maxsize,
@@ -33,26 +37,38 @@ def __init__(self,
self.min_len = min_len
self.max_len = max_len
- def compute_stats(self, sample, context=False):
- # check if it's computed already
- if StatsKeys.avg_line_length in sample[Fields.stats]:
- return sample
-
+ def compute_stats(self, samples, context=False):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
context_key = f'{InterVars.lines}'
- if context and context_key in sample[Fields.context]:
- lines = sample[Fields.context][context_key]
- else:
- lines = sample[self.text_key].splitlines()
- if context:
- sample[Fields.context][context_key] = lines
- sample[Fields.stats][StatsKeys.avg_line_length] = \
- len(sample[self.text_key]) / len(lines) \
- if len(lines) != 0 else 0.0
- return sample
-
- def process(self, sample):
- if self.min_len <= sample[Fields.stats][
- StatsKeys.avg_line_length] <= self.max_len:
- return True
+
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.avg_line_length in stat:
+ continue
+
+ cur_text = samples_list[idx]
+ if context and context_key in samples[Fields.context][idx]:
+ lines = samples[Fields.context][idx][context_key]
+ else:
+ lines = cur_text.splitlines()
+ if context:
+ samples[Fields.context][idx][context_key] = lines
+ samples_stats[idx][StatsKeys.avg_line_length] = \
+ len(cur_text) / len(lines) if len(lines) != 0 else 0.0
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_len <= stat[StatsKeys.avg_line_length
+ ] <= self.max_len,
+ samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_len <= samples[Fields.stats][
+ StatsKeys.avg_line_length] <= self.max_len:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py
index 1fb6949ff..a0441334a 100644
--- a/data_juicer/ops/filter/character_repetition_filter.py
+++ b/data_juicer/ops/filter/character_repetition_filter.py
@@ -15,6 +15,8 @@ class CharacterRepetitionFilter(Filter):
"""Filter to keep samples with char-level n-gram repetition ratio within a
specific range."""
+ _batched_op = True
+
def __init__(self,
rep_len: PositiveInt = 10,
min_ratio: float = 0.0,
@@ -39,40 +41,54 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio
- def compute_stats(self, sample):
- # check if it's computed already
- if StatsKeys.char_rep_ratio in sample[Fields.stats]:
- return sample
+ def compute_stats(self, samples):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
+
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.char_rep_ratio in stat:
+ continue
+
+ cur_text = samples_list[idx]
+ char_ngrams = [
+ cur_text[i:i + self.n]
+ for i in range(len(cur_text) - self.n + 1)
+ ]
+ freq_char_ngrams = {}
+ for char_ngram in char_ngrams:
+ freq_char_ngrams[char_ngram] = (
+ freq_char_ngrams.get(char_ngram, 0) + 1)
- char_ngrams = [
- sample[self.text_key][i:i + self.n]
- for i in range(len(sample[self.text_key]) - self.n + 1)
- ]
- freq_char_ngrams = {}
- for char_ngram in char_ngrams:
- freq_char_ngrams[char_ngram] = (
- freq_char_ngrams.get(char_ngram, 0) + 1)
+ if len(freq_char_ngrams) == 0:
+ samples_stats[idx][StatsKeys.char_rep_ratio] = 0.0
+ continue
- if len(freq_char_ngrams) == 0:
- sample[Fields.stats][StatsKeys.char_rep_ratio] = 0.0
- return sample
+ freq_char_ngrams = sorted(list(freq_char_ngrams.values()),
+ reverse=True)
+ num_no_rep_char_ngrams = len(
+ [el for el in freq_char_ngrams if el == 1])
+ num_rep_char_ngrams = min(
+ int(np.sqrt(len(freq_char_ngrams))),
+ len(freq_char_ngrams) - num_no_rep_char_ngrams,
+ )
+ samples_stats[idx][StatsKeys.char_rep_ratio] = (
+ sum(freq_char_ngrams[:num_rep_char_ngrams]) /
+ sum(freq_char_ngrams)) if sum(freq_char_ngrams) != 0 else 0.0
- freq_char_ngrams = sorted(list(freq_char_ngrams.values()),
- reverse=True)
- num_no_rep_char_ngrams = len(
- [el for el in freq_char_ngrams if el == 1])
- num_rep_char_ngrams = min(
- int(np.sqrt(len(freq_char_ngrams))),
- len(freq_char_ngrams) - num_no_rep_char_ngrams,
- )
- sample[Fields.stats][StatsKeys.char_rep_ratio] = (sum(
- freq_char_ngrams[:num_rep_char_ngrams]) / sum(freq_char_ngrams)) \
- if sum(freq_char_ngrams) != 0 else 0.0
- return sample
+ return samples
- def process(self, sample):
- if self.min_ratio <= sample[Fields.stats][StatsKeys.char_rep_ratio] \
- <= self.max_ratio:
- return True
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_ratio <= stat[
+ StatsKeys.char_rep_ratio] <= self.max_ratio,
+ samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_ratio <= samples[Fields.stats][
+ StatsKeys.char_rep_ratio] <= self.max_ratio:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py
index 2f2a4513e..146cfb0a2 100644
--- a/data_juicer/ops/filter/maximum_line_length_filter.py
+++ b/data_juicer/ops/filter/maximum_line_length_filter.py
@@ -5,13 +5,17 @@
from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_LINES
+OP_NAME = 'maximum_line_length_filter'
-@OPERATORS.register_module('maximum_line_length_filter')
-@INTER_LINES.register_module('maximum_line_length_filter')
+
+@OPERATORS.register_module(OP_NAME)
+@INTER_LINES.register_module(OP_NAME)
class MaximumLineLengthFilter(Filter):
"""Filter to keep samples with maximum line length within a specific
range."""
+ _batched_op = True
+
def __init__(self,
min_len: int = 10,
max_len: int = sys.maxsize,
@@ -33,26 +37,39 @@ def __init__(self,
self.min_len = min_len
self.max_len = max_len
- def compute_stats(self, sample, context=False):
- # check if it's computed already
- if StatsKeys.max_line_length in sample[Fields.stats]:
- return sample
-
+ def compute_stats(self, samples, context=False):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
context_key = f'{InterVars.lines}'
- if context and context_key in sample[Fields.context]:
- lines = sample[Fields.context][context_key]
- else:
- lines = sample[self.text_key].splitlines()
- if context:
- sample[Fields.context][context_key] = lines
- line_lengths = list(map(len, lines))
- sample[Fields.stats][StatsKeys.max_line_length] = max(
- line_lengths) if line_lengths else 0
- return sample
-
- def process(self, sample):
- if self.min_len <= sample[Fields.stats][
- StatsKeys.max_line_length] <= self.max_len:
- return True
+
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.max_line_length in stat:
+ continue
+
+ if context and context_key in samples[Fields.context][idx]:
+ lines = samples[Fields.context][idx][context_key]
+ else:
+ lines = samples_list[idx].splitlines()
+ if context:
+ samples[Fields.context][idx][context_key] = lines
+ line_lengths = list(map(len, lines))
+ samples_stats[idx][StatsKeys.max_line_length] = max(
+ line_lengths) if line_lengths else 0
+
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_len <= stat[StatsKeys.max_line_length
+ ] <= self.max_len,
+ samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_len <= samples[Fields.stats][
+ StatsKeys.max_line_length] <= self.max_len:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py
index 5d0b396f9..287d15a11 100644
--- a/data_juicer/ops/filter/perplexity_filter.py
+++ b/data_juicer/ops/filter/perplexity_filter.py
@@ -23,6 +23,8 @@ class PerplexityFilter(Filter):
"""Filter to keep samples with perplexity score less than a specific max
value."""
+ _batched_op = True
+
def __init__(self,
lang: str = 'en',
max_ppl: float = 1500,
@@ -44,33 +46,42 @@ def __init__(self,
lang=lang)
self.kl_model_key = prepare_model(model_type='kenlm', lang=lang)
- def compute_stats(self, sample, context=False):
- # check if it's computed already
- if StatsKeys.perplexity in sample[Fields.stats]:
- return sample
-
- # tokenization
+ def compute_stats(self, samples, context=False):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
words_key = f'{InterVars.words}-{self.sp_model_key}'
- if context and words_key in sample[Fields.context]:
- words = sample[Fields.context][words_key]
- else:
- tokenizer = get_model(self.sp_model_key)
- words = get_words_from_document(
- sample[self.text_key],
- token_func=tokenizer.encode_as_pieces if tokenizer else None)
- if context:
- sample[Fields.context][words_key] = words
- text = ' '.join(words)
- # compute perplexity
- logits, length = 0, 0
- kenlm_model = get_model(self.kl_model_key)
- for line in text.splitlines():
- logits += kenlm_model.score(line)
- length += (len(line.split()) + 1)
- ppl = (10.0**(-logits / length)) if length != 0 else 0.0
- sample[Fields.stats][StatsKeys.perplexity] = round(ppl, 1)
- return sample
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.perplexity in stat:
+ continue
+ # tokenization
+ if context and words_key in samples[Fields.context][idx]:
+ words = samples[Fields.context][idx][words_key]
+ else:
+ tokenizer = get_model(self.sp_model_key)
+ words = get_words_from_document(
+ samples_list[idx],
+ token_func=tokenizer.encode_as_pieces
+ if tokenizer else None)
+ if context:
+ samples[Fields.context][idx][words_key] = words
+ text = ' '.join(words)
+ # compute perplexity
+ logits, length = 0, 0
+ kenlm_model = get_model(self.kl_model_key)
+ for line in text.splitlines():
+ logits += kenlm_model.score(line)
+ length += (len(line.split()) + 1)
+ ppl = (10.0**(-logits / length)) if length != 0 else 0.0
+ samples_stats[idx][StatsKeys.perplexity] = round(ppl, 1)
- def process(self, sample):
- return sample[Fields.stats][StatsKeys.perplexity] <= self.max_ppl
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl,
+ samples[Fields.stats]))
+ else:
+ return samples[Fields.stats][StatsKeys.perplexity] <= self.max_ppl
diff --git a/data_juicer/ops/filter/special_characters_filter.py b/data_juicer/ops/filter/special_characters_filter.py
index dc9ef1ed6..0b56f390e 100644
--- a/data_juicer/ops/filter/special_characters_filter.py
+++ b/data_juicer/ops/filter/special_characters_filter.py
@@ -13,6 +13,8 @@ class SpecialCharactersFilter(Filter):
"""Filter to keep samples with special-char ratio within a specific
range."""
+ _batched_op = True
+
def __init__(self,
min_ratio: float = 0.0,
max_ratio: float = 0.25,
@@ -34,23 +36,34 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio
- def compute_stats(self, sample):
- # check if it's computed already
- if StatsKeys.special_char_ratio in sample[Fields.stats]:
- return sample
-
- # get ratio of special characters
- sample[Fields.stats][StatsKeys.special_char_ratio] = (
- len([c
- for c in sample[self.text_key] if c in SPECIAL_CHARACTERS]) /
- len(sample[self.text_key])) if len(
- sample[self.text_key]) != 0 else 0.0
- return sample
-
- def process(self, sample):
- if self.min_ratio <= \
- sample[Fields.stats][StatsKeys.special_char_ratio] \
- <= self.max_ratio:
- return True
+ def compute_stats(self, samples):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
+
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.special_char_ratio in stat:
+ continue
+ cur_text = samples_list[idx]
+ # get ratio of special characters
+ samples_stats[idx][StatsKeys.special_char_ratio] = (
+ len([c for c in cur_text if c in SPECIAL_CHARACTERS]) /
+ len(cur_text)) if len(cur_text) != 0 else 0.0
+
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_ratio <= stat[
+ StatsKeys.special_char_ratio] <= self.max_ratio,
+ samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_ratio <= \
+ samples[Fields.stats][StatsKeys.special_char_ratio] \
+ <= self.max_ratio:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py
index 6fa966889..ec61f8304 100644
--- a/data_juicer/ops/filter/text_length_filter.py
+++ b/data_juicer/ops/filter/text_length_filter.py
@@ -10,6 +10,8 @@ class TextLengthFilter(Filter):
"""Filter to keep samples with total text length within a specific
range."""
+ _batched_op = True
+
def __init__(self,
min_len: int = 10,
max_len: int = sys.maxsize,
@@ -31,17 +33,28 @@ def __init__(self,
self.min_len = min_len
self.max_len = max_len
- def compute_stats(self, sample):
- # check if it's computed already
- if StatsKeys.text_len in sample[Fields.stats]:
- return sample
-
- sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
- return sample
-
- def process(self, sample):
- if self.min_len <= sample[Fields.stats][
- StatsKeys.text_len] <= self.max_len:
- return True
+ def compute_stats(self, samples):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
+ for i, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.text_len in stat:
+ continue
+ else:
+ samples_stats[i][StatsKeys.text_len] = len(samples_list[i])
+
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_len <= stat[StatsKeys.text_len] <=
+ self.max_len, samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_len <= samples[Fields.stats][
+ StatsKeys.text_len] <= self.max_len:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py
index 3009c55f9..71f806e25 100644
--- a/data_juicer/ops/filter/word_repetition_filter.py
+++ b/data_juicer/ops/filter/word_repetition_filter.py
@@ -25,6 +25,8 @@ class WordRepetitionFilter(Filter):
"""Filter to keep samples with word-level n-gram repetition ratio within a
specific range."""
+ _batched_op = True
+
def __init__(self,
lang: str = 'en',
tokenization: bool = False,
@@ -59,57 +61,70 @@ def __init__(self,
self.model_key = prepare_model(model_type='sentencepiece',
lang=lang)
- def compute_stats(self, sample, context=False):
- # check if it's computed already
- if StatsKeys.word_rep_ratio in sample[Fields.stats]:
- return sample
-
- # try to get words from context
+ def compute_stats(self, samples, context=False):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
words_key = f'{InterVars.words}-{self.model_key}'
- if context and words_key in sample[Fields.context]:
- words = sample[Fields.context][words_key]
- else:
- tokenizer = get_model(self.model_key)
- words = get_words_from_document(
- sample[self.text_key],
- token_func=tokenizer.encode_as_pieces if tokenizer else None)
- if context:
- sample[Fields.context][words_key] = words
-
- # try to get refined words from context
- refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \
- f'False-[2]-'
- if context and refined_words_key in sample[Fields.context]:
- words = sample[Fields.context][refined_words_key]
- else:
- words = words_refinement(words,
- lower_case=True,
- strip_chars=SPECIAL_CHARACTERS)
- if context:
- sample[Fields.context][refined_words_key] = words
- word_ngrams = [
- ' '.join(words[i:i + self.n])
- for i in range(len(words) - self.n + 1)
- ]
- freq_word_ngrams = {}
- for word_ngram in word_ngrams:
- freq_word_ngrams[word_ngram] = (
- freq_word_ngrams.get(word_ngram, 0) + 1)
-
- if len(freq_word_ngrams) == 0:
- sample[Fields.stats][StatsKeys.word_rep_ratio] = 0.0
- return sample
-
- freq_word_ngrams = list(freq_word_ngrams.values())
- rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1]
- sample[Fields.stats][StatsKeys.word_rep_ratio] = (
- sum(rep_more_than_one) /
- sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0
- return sample
-
- def process(self, sample):
- if self.min_ratio <= sample[Fields.stats][StatsKeys.word_rep_ratio] \
- <= self.max_ratio:
- return True
+
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.word_rep_ratio in stat:
+ continue
+ # try to get words from context
+ if context and words_key in samples[Fields.context][idx]:
+ words = samples[Fields.context][idx][words_key]
+ else:
+ tokenizer = get_model(self.model_key)
+ words = get_words_from_document(
+ samples_list[idx],
+ token_func=tokenizer.encode_as_pieces
+ if tokenizer else None)
+ if context:
+ samples[Fields.context][idx][words_key] = words
+
+ # try to get refined words from context
+ refined_words_key = f'{InterVars.refined_words}-' \
+ f'True-SPECIAL_CHARS-False-[2]-'
+ if context and refined_words_key in samples[Fields.context][idx]:
+ words = samples[Fields.context][idx][refined_words_key]
+ else:
+ words = words_refinement(words,
+ lower_case=True,
+ strip_chars=SPECIAL_CHARACTERS)
+ if context:
+ samples[Fields.context][idx][refined_words_key] = words
+ word_ngrams = [
+ ' '.join(words[i:i + self.n])
+ for i in range(len(words) - self.n + 1)
+ ]
+ freq_word_ngrams = {}
+ for word_ngram in word_ngrams:
+ freq_word_ngrams[word_ngram] = (
+ freq_word_ngrams.get(word_ngram, 0) + 1)
+
+ if len(freq_word_ngrams) == 0:
+ samples_stats[idx][StatsKeys.word_rep_ratio] = 0.0
+ continue
+
+ freq_word_ngrams = list(freq_word_ngrams.values())
+ rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1]
+ samples_stats[idx][StatsKeys.word_rep_ratio] = (
+ sum(rep_more_than_one) /
+ sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0
+
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_ratio <= stat[
+ StatsKeys.word_rep_ratio] <= self.max_ratio,
+ samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_ratio <= samples[Fields.stats][
+ StatsKeys.word_rep_ratio] <= self.max_ratio:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py
index 7d354cb54..07eb8e2b7 100644
--- a/data_juicer/ops/filter/words_num_filter.py
+++ b/data_juicer/ops/filter/words_num_filter.py
@@ -21,6 +21,8 @@ class WordsNumFilter(Filter):
"""Filter to keep samples with total words number within a specific
range."""
+ _batched_op = True
+
def __init__(self,
lang: str = 'en',
tokenization: bool = False,
@@ -52,28 +54,40 @@ def __init__(self,
self.model_key = prepare_model(model_type='sentencepiece',
lang=lang)
- def compute_stats(self, sample, context=False):
- # check if it's computed already
- if StatsKeys.num_words in sample[Fields.stats]:
- return sample
-
+ def compute_stats(self, samples, context=False):
+ samples_list = samples[self.text_key]
+ samples_stats = samples[Fields.stats]
words_key = f'{InterVars.words}-{self.model_key}'
- if context and words_key in sample[Fields.context]:
- words = sample[Fields.context][words_key]
- else:
- tokenizer = get_model(self.model_key)
- words = get_words_from_document(
- sample[self.text_key],
- token_func=tokenizer.encode_as_pieces if tokenizer else None)
- if context:
- sample[Fields.context][words_key] = words
- words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
- sample[Fields.stats][StatsKeys.num_words] = len(words)
- return sample
- def process(self, sample):
- if self.min_num <= sample[Fields.stats][
- StatsKeys.num_words] <= self.max_num:
- return True
+ for idx, stat in enumerate(samples_stats):
+ # check if it's computed already
+ if StatsKeys.num_words in stat:
+ continue
+ if context and words_key in samples[Fields.context][idx]:
+ words = samples[Fields.context][idx][words_key]
+ else:
+ tokenizer = get_model(self.model_key)
+ words = get_words_from_document(
+ samples_list[idx],
+ token_func=tokenizer.encode_as_pieces
+ if tokenizer else None)
+ if context:
+ samples[Fields.context][idx][words_key] = words
+ words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
+ samples_stats[idx][StatsKeys.num_words] = len(words)
+
+ return samples
+
+ def process(self, samples):
+ if isinstance(samples[Fields.stats], list):
+ return list(
+ map(
+ lambda stat: self.min_num <= stat[StatsKeys.num_words] <=
+ self.max_num, samples[Fields.stats]))
else:
- return False
+ # single sample for ray filter
+ if self.min_num <= samples[Fields.stats][
+ StatsKeys.num_words] <= self.max_num:
+ return True
+ else:
+ return False
diff --git a/data_juicer/ops/mapper/chinese_convert_mapper.py b/data_juicer/ops/mapper/chinese_convert_mapper.py
index 818f1b1d4..9236ddaa2 100644
--- a/data_juicer/ops/mapper/chinese_convert_mapper.py
+++ b/data_juicer/ops/mapper/chinese_convert_mapper.py
@@ -27,6 +27,8 @@ class ChineseConvertMapper(Mapper):
"""Mapper to convert Chinese between Traditional Chinese, Simplified Chinese
and Japanese Kanji."""
+ _batched_op = True
+
def __init__(self, mode: str = 's2t', *args, **kwargs):
"""
Initialization method.
@@ -82,8 +84,10 @@ def __init__(self, mode: str = 's2t', *args, **kwargs):
self.mode = mode
prepare_converter(self.mode)
- def process(self, sample):
+ def process(self, samples):
prepare_converter(self.mode)
- sample[self.text_key] = OPENCC_CONVERTER.convert(sample[self.text_key])
- return sample
+ samples[self.text_key] = list(
+ map(lambda text: OPENCC_CONVERTER.convert(text),
+ samples[self.text_key]))
+ return samples
diff --git a/data_juicer/ops/mapper/clean_copyright_mapper.py b/data_juicer/ops/mapper/clean_copyright_mapper.py
index dabb0cd40..3bf6fcbdf 100644
--- a/data_juicer/ops/mapper/clean_copyright_mapper.py
+++ b/data_juicer/ops/mapper/clean_copyright_mapper.py
@@ -12,6 +12,8 @@ class CleanCopyrightMapper(Mapper):
"""Mapper to clean copyright comments at the beginning of the text
samples."""
+ _batched_op = True
+
def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -23,21 +25,19 @@ def __init__(self, *args, **kwargs):
self.pat = re.compile('/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/')
self.cpat = re.compile('copyright', re.IGNORECASE)
- def process(self, sample):
-
- r = self.pat.search(sample[self.text_key])
+ def _process_single_sample(self, sample):
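+        # operates on the raw text of one sample, not on the sample dict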
+ r = self.pat.search(sample)
if r:
# found one, now see if it contains "copyright", if so strip it
span = r.span()
- sub = sample[self.text_key][span[0]:span[1]]
+ sub = sample[span[0]:span[1]]
if self.cpat.search(sub):
# cut it
- sample[self.text_key] = sample[
- self.text_key][:span[0]] + sample[self.text_key][span[1]:]
+ sample = sample[:span[0]] + sample[span[1]:]
return sample
- lines = sample[self.text_key].split('\n')
+ lines = sample.split('\n')
skip = 0
# Greedy replace any file that begins with comment block, most
@@ -51,5 +51,11 @@ def process(self, sample):
if skip:
# we skipped, consume it
- sample[self.text_key] = '\n'.join(lines[skip:])
+ sample = '\n'.join(lines[skip:])
return sample
+
+ def process(self, samples):
+ samples[self.text_key] = list(
+ map(lambda text: self._process_single_sample(text),
+ samples[self.text_key]))
+ return samples
diff --git a/data_juicer/ops/mapper/clean_email_mapper.py b/data_juicer/ops/mapper/clean_email_mapper.py
index b8d2a1cbb..90dcca60f 100644
--- a/data_juicer/ops/mapper/clean_email_mapper.py
+++ b/data_juicer/ops/mapper/clean_email_mapper.py
@@ -9,6 +9,8 @@
class CleanEmailMapper(Mapper):
"""Mapper to clean email in text samples."""
+ _batched_op = True
+
def __init__(self,
pattern: Optional[str] = None,
repl: str = '',
@@ -34,13 +36,13 @@ def __init__(self,
self.repl = repl
- def process(self, sample):
-
- if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ if not re.search(self.pattern, text, flags=re.DOTALL):
+ continue
+ samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+ repl=self.repl,
+ string=text,
+ flags=re.DOTALL)
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=self.repl,
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ return samples
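
Batched Mappers follow the same column-wise contract; a minimal sketch with
the default email pattern (strings illustrative):

    from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper

    op = CleanEmailMapper()
    samples = {'text': ['contact me at foo@bar.com', 'no address here']}
    samples = op.process(samples)
    # samples['text'] -> ['contact me at ', 'no address here']
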
diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py
index 5c2c30c57..d959cc85f 100644
--- a/data_juicer/ops/mapper/clean_html_mapper.py
+++ b/data_juicer/ops/mapper/clean_html_mapper.py
@@ -16,6 +16,8 @@
class CleanHtmlMapper(Mapper):
"""Mapper to clean html code in text samples."""
+ _batched_op = True
+
def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -25,7 +27,7 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)
- def process(self, sample):
+ def process(self, samples):
def _clean_html(raw_html):
             raw_html = raw_html.replace('<li>', '\n*')
@@ -35,5 +37,6 @@ def _clean_html(raw_html):
parser = HTMLParser(raw_html)
return parser.text()
- sample[self.text_key] = _clean_html(sample[self.text_key])
- return sample
+ samples[self.text_key] = list(
+ map(lambda text: _clean_html(text), samples[self.text_key]))
+ return samples
diff --git a/data_juicer/ops/mapper/clean_ip_mapper.py b/data_juicer/ops/mapper/clean_ip_mapper.py
index b36d13aae..709037ddd 100644
--- a/data_juicer/ops/mapper/clean_ip_mapper.py
+++ b/data_juicer/ops/mapper/clean_ip_mapper.py
@@ -9,6 +9,8 @@
class CleanIpMapper(Mapper):
"""Mapper to clean ipv4 and ipv6 address in text samples."""
+ _batched_op = True
+
def __init__(self,
pattern: Optional[str] = None,
repl: str = '',
@@ -38,13 +40,12 @@ def __init__(self,
self.pattern = pattern[2:-1]
self.repl = repl
- def process(self, sample):
-
- if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
- return sample
-
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=self.repl,
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ if not re.search(self.pattern, text, flags=re.DOTALL):
+ continue
+ samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+ repl=self.repl,
+ string=text,
+ flags=re.DOTALL)
+ return samples
diff --git a/data_juicer/ops/mapper/clean_links_mapper.py b/data_juicer/ops/mapper/clean_links_mapper.py
index ebeac8668..f08abc78f 100644
--- a/data_juicer/ops/mapper/clean_links_mapper.py
+++ b/data_juicer/ops/mapper/clean_links_mapper.py
@@ -12,6 +12,8 @@
class CleanLinksMapper(Mapper):
"""Mapper to clean links like http/https/ftp in text samples."""
+ _batched_op = True
+
def __init__(self,
pattern: Optional[str] = None,
repl: str = '',
@@ -44,13 +46,13 @@ def __init__(self,
self.pattern = pattern[2:-1]
self.repl = repl
- def process(self, sample):
-
- if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ if not re.search(self.pattern, text, flags=re.DOTALL):
+ continue
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=self.repl,
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+ repl=self.repl,
+ string=text,
+ flags=re.DOTALL)
+ return samples
diff --git a/data_juicer/ops/mapper/expand_macro_mapper.py b/data_juicer/ops/mapper/expand_macro_mapper.py
index 2f5d7fe83..b83455103 100644
--- a/data_juicer/ops/mapper/expand_macro_mapper.py
+++ b/data_juicer/ops/mapper/expand_macro_mapper.py
@@ -12,6 +12,8 @@ class ExpandMacroMapper(Mapper):
"""Mapper to expand macro definitions in the document body of Latex
samples."""
+ _batched_op = True
+
def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -55,26 +57,29 @@ def _build_non_arg_macros_dict(self, file_content):
macros[macro_name] = macro_val
return macros
- def process(self, sample):
- non_arg_macros = self._build_non_arg_macros_dict(sample[self.text_key])
-
- # TODO: macros that take arguments are not supported yet
- arg_macros = {}
-
- # inline-expand all non-arg macros
- for macro_name, macro_value in non_arg_macros.items():
- sample[self.text_key] = re.sub(
- # make pattern grouped to make sure that the macro is not part
- # of a longer alphanumeric word
- pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])',
- # replace the macro with its value and add back the character
- # that was matched after the macro
- repl=macro_value + r'\2',
- string=sample[self.text_key])
-
- # inline-expand all macros that use args
- # TODO: inline-expand macros with args
- for macro_name, macro_value in arg_macros.items():
- pass
-
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ non_arg_macros = self._build_non_arg_macros_dict(text)
+
+ # TODO: macros that take arguments are not supported yet
+ arg_macros = {}
+
+ # inline-expand all non-arg macros
+ for macro_name, macro_value in non_arg_macros.items():
+ text = re.sub(
+ # make pattern grouped to make sure that the macro
+ # is not part of a longer alphanumeric word
+ pattern=r'(' + macro_name + r')' + r'([^a-zA-Z0-9])',
+ # replace the macro with its value and add back the
+ # character that was matched after the macro
+ repl=macro_value + r'\2',
+ string=text)
+
+ # inline-expand all macros that use args
+ # TODO: inline-expand macros with args
+ for macro_name, macro_value in arg_macros.items():
+ pass
+
+ samples[self.text_key][idx] = text
+
+ return samples
diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py
index a7d06da3a..4ca71c30a 100644
--- a/data_juicer/ops/mapper/fix_unicode_mapper.py
+++ b/data_juicer/ops/mapper/fix_unicode_mapper.py
@@ -12,6 +12,8 @@
class FixUnicodeMapper(Mapper):
"""Mapper to fix unicode errors in text samples."""
+ _batched_op = True
+
def __init__(self, normalization: str = None, *args, **kwargs):
"""
Initialization method.
@@ -33,7 +35,10 @@ def __init__(self, normalization: str = None, *args, **kwargs):
'supported. Can only be one of '
'["NFC", "NFKC", "NFD", "NFKD"]')
- def process(self, sample):
- sample[self.text_key] = ftfy.fix_text(sample[self.text_key],
- normalization=self.normalization)
- return sample
+ def process(self, samples):
+ samples[self.text_key] = list(
+ map(
+ lambda text: ftfy.fix_text(text,
+ normalization=self.normalization),
+ samples[self.text_key]))
+ return samples
diff --git a/data_juicer/ops/mapper/punctuation_normalization_mapper.py b/data_juicer/ops/mapper/punctuation_normalization_mapper.py
index b6640e9eb..6531833a3 100644
--- a/data_juicer/ops/mapper/punctuation_normalization_mapper.py
+++ b/data_juicer/ops/mapper/punctuation_normalization_mapper.py
@@ -10,6 +10,8 @@ class PunctuationNormalizationMapper(Mapper):
"""Mapper to normalize unicode punctuations to English punctuations in text
samples."""
+ _batched_op = True
+
def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -55,8 +57,10 @@ def __init__(self, *args, **kwargs):
'►': '-',
}
- def process(self, sample):
- sample[self.text_key] = ''.join([
- self.punctuation_unicode.get(c, c) for c in sample[self.text_key]
- ])
- return sample
+ def process(self, samples):
+ samples[self.text_key] = list(
+ map(
+ lambda text: ''.join(
+ [self.punctuation_unicode.get(c, c) for c in text]),
+ samples[self.text_key]))
+ return samples
diff --git a/data_juicer/ops/mapper/remove_bibliography_mapper.py b/data_juicer/ops/mapper/remove_bibliography_mapper.py
index 2ce852d66..d2a2bf342 100644
--- a/data_juicer/ops/mapper/remove_bibliography_mapper.py
+++ b/data_juicer/ops/mapper/remove_bibliography_mapper.py
@@ -12,6 +12,8 @@ class RemoveBibliographyMapper(Mapper):
"""Mapper to remove bibliography at the end of documents in Latex
samples."""
+ _batched_op = True
+
def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -27,9 +29,12 @@ def __init__(self, *args, **kwargs):
self.pattern += r'\\bibliography\{.*\}'
self.pattern += r').*$'
- def process(self, sample):
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=r'',
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ def process(self, samples):
+ samples[self.text_key] = list(
+ map(
+ lambda text: re.sub(pattern=self.pattern,
+ repl=r'',
+ string=text,
+ flags=re.DOTALL), samples[self.text_key]))
+
+ return samples
diff --git a/data_juicer/ops/mapper/remove_comments_mapper.py b/data_juicer/ops/mapper/remove_comments_mapper.py
index c5f083c14..09fe4e5ef 100644
--- a/data_juicer/ops/mapper/remove_comments_mapper.py
+++ b/data_juicer/ops/mapper/remove_comments_mapper.py
@@ -17,6 +17,8 @@ class RemoveCommentsMapper(Mapper):
Only support 'tex' for now.
"""
+ _batched_op = True
+
def __init__(self,
doc_type: Union[str, List[str]] = 'tex',
inline: bool = True,
@@ -37,19 +39,23 @@ def __init__(self,
self.inline = inline
self.multiline = multiline
- def process(self, sample):
+ def process(self, samples):
# TODO: remove different comments by sample type
- if self.inline:
- # remove all in comments within a line
- sample[self.text_key] = re.sub(pattern=r'[^\\]%.+$',
- repl=r'',
- string=sample[self.text_key],
- flags=re.MULTILINE)
-
- if self.multiline:
- sample[self.text_key] = re.sub(pattern=r'(?m)^%.*\n?',
- repl=r'',
- string=sample[self.text_key],
- flags=re.MULTILINE)
- return sample
+ for idx, text in enumerate(samples[self.text_key]):
+ if self.inline:
+ # remove all in comments within a line
+ text = re.sub(pattern=r'[^\\]%.+$',
+ repl=r'',
+ string=text,
+ flags=re.MULTILINE)
+
+ if self.multiline:
+ text = re.sub(pattern=r'(?m)^%.*\n?',
+ repl=r'',
+ string=text,
+ flags=re.MULTILINE)
+
+ samples[self.text_key][idx] = text
+
+ return samples
diff --git a/data_juicer/ops/mapper/remove_header_mapper.py b/data_juicer/ops/mapper/remove_header_mapper.py
index 8371d2f99..bb967e929 100644
--- a/data_juicer/ops/mapper/remove_header_mapper.py
+++ b/data_juicer/ops/mapper/remove_header_mapper.py
@@ -12,6 +12,8 @@ class RemoveHeaderMapper(Mapper):
"""Mapper to remove headers at the beginning of documents in Latex
samples."""
+ _batched_op = True
+
def __init__(self, drop_no_head: bool = True, *args, **kwargs):
"""
Initialization method.
@@ -34,15 +36,17 @@ def __init__(self, drop_no_head: bool = True, *args, **kwargs):
self.drop_no_head = drop_no_head
- def process(self, sample):
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ if not re.search(self.pattern, text, flags=re.DOTALL):
+ if self.drop_no_head:
+                    samples[self.text_key][idx] = ''
+ continue
+ text = re.sub(pattern=self.pattern,
+ repl=r'\2',
+ string=text,
+ flags=re.DOTALL)
- if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
- if self.drop_no_head:
- sample[self.text_key] = ''
- return sample
+ samples[self.text_key][idx] = text
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=r'\2',
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ return samples
diff --git a/data_juicer/ops/mapper/remove_long_words_mapper.py b/data_juicer/ops/mapper/remove_long_words_mapper.py
index ff8fa2d29..5aea47516 100644
--- a/data_juicer/ops/mapper/remove_long_words_mapper.py
+++ b/data_juicer/ops/mapper/remove_long_words_mapper.py
@@ -13,6 +13,8 @@
class RemoveLongWordsMapper(Mapper):
"""Mapper to remove long words within a specific range."""
+ _batched_op = True
+
def __init__(self,
min_len: int = 1,
max_len: int = sys.maxsize,
@@ -41,11 +43,13 @@ def should_keep_long_word(self, word):
else:
return False
- def process(self, sample):
-
- sentences = split_on_newline_tab_whitespace(sample[self.text_key])
- sentences = [[[
- word for word in subsentence if self.should_keep_long_word(word)
- ] for subsentence in sentence] for sentence in sentences]
- sample[self.text_key] = merge_on_whitespace_tab_newline(sentences)
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ sentences = split_on_newline_tab_whitespace(text)
+ sentences = [[[
+ word for word in subsentence
+ if self.should_keep_long_word(word)
+ ] for subsentence in sentence] for sentence in sentences]
+ samples[self.text_key][idx] = merge_on_whitespace_tab_newline(
+ sentences)
+ return samples
diff --git a/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py b/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py
index 3e6cd494d..371697efc 100644
--- a/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py
+++ b/data_juicer/ops/mapper/remove_non_chinese_character_mapper.py
@@ -7,6 +7,8 @@
class RemoveNonChineseCharacterlMapper(Mapper):
"""Mapper to remove non chinese Character in text samples."""
+ _batched_op = True
+
def __init__(self,
keep_alphabet: bool = True,
keep_number: bool = True,
@@ -33,13 +35,13 @@ def __init__(self,
else:
self.pattern += u']'
- def process(self, sample):
-
- if not re.search(self.pattern, sample[self.text_key], flags=re.DOTALL):
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ if not re.search(self.pattern, text, flags=re.DOTALL):
+ continue
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=r'',
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+ repl=r'',
+ string=text,
+ flags=re.DOTALL)
+ return samples
diff --git a/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py b/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py
index a1069d24d..add0a719e 100644
--- a/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py
+++ b/data_juicer/ops/mapper/remove_repeat_sentences_mapper.py
@@ -15,6 +15,8 @@ def split_sentence(text):
class RemoveRepeatSentencesMapper(Mapper):
"""Mapper to remove repeat sentences in text samples."""
+ _batched_op = True
+
def __init__(self,
lowercase: bool = False,
ignore_special_character: bool = True,
@@ -43,28 +45,29 @@ def __init__(self,
self.remove_regex = re.compile(r'[^a-zA-Z0-9\u4e00-\u9fa5\n\t ]'
) if ignore_special_character else None
- def process(self, sample):
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ lines = [e for e in text.split('\n')]
+ new_lines = []
+ hash_set = set([])
+ for line in lines:
+ new_sent = ''
+ if line:
+ sentences = split_sentence(line)
+ for sentence in sentences:
+ copy = sentence.strip()
+ if self.lowercase:
+ copy = copy.lower()
+ if self.remove_regex:
+ copy = self.remove_regex.sub('', copy)
- lines = [e for e in sample[self.text_key].split('\n')]
- new_lines = []
- hash_set = set([])
- for line in lines:
- new_sent = ''
- if line:
- sentences = split_sentence(line)
- for sentence in sentences:
- copy = sentence.strip()
- if self.lowercase:
- copy = copy.lower()
- if self.remove_regex:
- copy = self.remove_regex.sub('', copy)
+ if len(copy) < self.min_repeat_sentence_length:
+ new_sent += sentence
+ elif copy not in hash_set:
+ new_sent += sentence
+ hash_set.add(copy)
+ new_lines.append(new_sent)
- if len(copy) < self.min_repeat_sentence_length:
- new_sent += sentence
- elif copy not in hash_set:
- new_sent += sentence
- hash_set.add(copy)
- new_lines.append(new_sent)
+ samples[self.text_key][idx] = '\n'.join(new_lines)
- sample[self.text_key] = '\n'.join(new_lines)
- return sample
+ return samples
diff --git a/data_juicer/ops/mapper/remove_specific_chars_mapper.py b/data_juicer/ops/mapper/remove_specific_chars_mapper.py
index 99e15afef..d487efa2f 100644
--- a/data_juicer/ops/mapper/remove_specific_chars_mapper.py
+++ b/data_juicer/ops/mapper/remove_specific_chars_mapper.py
@@ -9,6 +9,8 @@
class RemoveSpecificCharsMapper(Mapper):
"""Mapper to clean specific chars in text samples."""
+ _batched_op = True
+
def __init__(self,
chars_to_remove: Union[str, List[str]] = '◆●■►▼▲▴∆▻▷❖♡□',
*args,
@@ -28,13 +30,14 @@ def __init__(self,
else:
self.pattern = None
- def process(self, sample):
-
+ def process(self, samples):
if self.pattern is None:
- return sample
-
- sample[self.text_key] = re.sub(pattern=self.pattern,
- repl=r'',
- string=sample[self.text_key],
- flags=re.DOTALL)
- return sample
+ return samples
+
+ samples[self.text_key] = list(
+ map(
+ lambda text: re.sub(pattern=self.pattern,
+ repl=r'',
+ string=text,
+ flags=re.DOTALL), samples[self.text_key]))
+ return samples
diff --git a/data_juicer/ops/mapper/remove_table_text_mapper.py b/data_juicer/ops/mapper/remove_table_text_mapper.py
index ca12104c0..8273c8dab 100644
--- a/data_juicer/ops/mapper/remove_table_text_mapper.py
+++ b/data_juicer/ops/mapper/remove_table_text_mapper.py
@@ -14,6 +14,8 @@ class RemoveTableTextMapper(Mapper):
number of tables.
"""
+ _batched_op = True
+
def __init__(self,
min_col: Annotated[int, Field(ge=2, le=20)] = 2,
max_col: Annotated[int, Field(ge=2, le=20)] = 20,
@@ -32,12 +34,12 @@ def __init__(self,
self.max_col = max_col
self.pattern = r'(?<=\n)((\S+?)([ |\t](\S+?)){%d}\n+){2,}'
- def process(self, sample):
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+            for i in range(self.min_col - 1, self.max_col):
+                pattern = re.compile(self.pattern % i)
+ text = pattern.sub('', text)
- text = sample[self.text_key]
- for i in range(self.min_col - 1, self.max_col):
- pattern = re.compile(self.pattern % i)
- text = pattern.sub('', text)
+ samples[self.text_key][idx] = text
- sample[self.text_key] = text
- return sample
+ return samples
diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
index d262c1d17..97bb61319 100644
--- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
+++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
@@ -18,6 +18,8 @@
class RemoveWordsWithIncorrectSubstringsMapper(Mapper):
"""Mapper to remove words with incorrect substrings."""
+ _batched_op = True
+
def __init__(self,
lang: str = 'en',
tokenization: bool = False,
@@ -48,25 +50,30 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings):
should_keep = all([(i_substr not in word) for i_substr in substrings])
return should_keep
- def process(self, sample):
- if self.tokenization:
- tokenizer = get_model(self.model_key)
- sentences = get_words_from_document(
- sample[self.text_key],
- token_func=tokenizer.encode_as_pieces if tokenizer else None)
- words = [
- word.replace('▁', '') for word in sentences
- if self.should_keep_word_with_incorrect_substrings(
- word.replace('▁', ''), self.substrings)
- ]
- if len(words) != len(sentences):
- sample[self.text_key] = ''.join(words)
- else:
- sentences = split_on_newline_tab_whitespace(sample[self.text_key])
- sentences = [[[
- word for word in subsentence
- if self.should_keep_word_with_incorrect_substrings(
- word, self.substrings)
- ] for subsentence in sentence] for sentence in sentences]
- sample[self.text_key] = merge_on_whitespace_tab_newline(sentences)
- return sample
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ if self.tokenization:
+ tokenizer = get_model(self.model_key)
+ sentences = get_words_from_document(
+ text,
+ token_func=tokenizer.encode_as_pieces
+ if tokenizer else None)
+ words = [
+ word.replace('▁', '') for word in sentences
+ if self.should_keep_word_with_incorrect_substrings(
+ word.replace('▁', ''), self.substrings)
+ ]
+ if len(words) != len(sentences):
+ text = ''.join(words)
+ else:
+ sentences = split_on_newline_tab_whitespace(text)
+ sentences = [[[
+ word for word in subsentence
+ if self.should_keep_word_with_incorrect_substrings(
+ word, self.substrings)
+ ] for subsentence in sentence] for sentence in sentences]
+ text = merge_on_whitespace_tab_newline(sentences)
+
+ samples[self.text_key][idx] = text
+
+ return samples
diff --git a/data_juicer/ops/mapper/replace_content_mapper.py b/data_juicer/ops/mapper/replace_content_mapper.py
index d16e4ec7c..324cc6357 100644
--- a/data_juicer/ops/mapper/replace_content_mapper.py
+++ b/data_juicer/ops/mapper/replace_content_mapper.py
@@ -11,6 +11,8 @@ class ReplaceContentMapper(Mapper):
a specific regular expression pattern with a designated
replacement string."""
+ _batched_op = True
+
def __init__(self,
pattern: Union[str, List[str], None] = None,
repl: Union[str, List[str]] = '',
@@ -42,21 +44,23 @@ def _prepare_pattern(self, pattern: str) -> re.Pattern:
pattern = pattern[2:-1]
return re.compile(pattern, flags=re.DOTALL)
- def process(self, sample):
+ def process(self, samples):
if self.pattern is None:
- return sample
-
- for i, pattern in enumerate(self.compiled_patterns):
- if isinstance(self.repl, list) and i < len(self.repl):
- replacement = self.repl[i]
- elif isinstance(self.repl, list) and i >= len(self.repl):
- raise ValueError(f"pattern length: {len(self.pattern)} '"
- f'must be equal to '
- f'repl length: {len(self.repl)}')
- else:
- replacement = self.repl
-
- sample[self.text_key] = pattern.sub(replacement,
- sample[self.text_key])
-
- return sample
+ return samples
+
+ for idx, text in enumerate(samples[self.text_key]):
+ for i, pattern in enumerate(self.compiled_patterns):
+ if isinstance(self.repl, list) and i < len(self.repl):
+ replacement = self.repl[i]
+ elif isinstance(self.repl, list) and i >= len(self.repl):
+                    raise ValueError(f'pattern length: {len(self.pattern)} '
+ f'must be equal to '
+ f'repl length: {len(self.repl)}')
+ else:
+ replacement = self.repl
+
+ text = pattern.sub(replacement, text)
+
+ samples[self.text_key][idx] = text
+
+ return samples
diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py
index 522c01300..819cfd55c 100644
--- a/data_juicer/ops/mapper/sentence_split_mapper.py
+++ b/data_juicer/ops/mapper/sentence_split_mapper.py
@@ -14,6 +14,8 @@
class SentenceSplitMapper(Mapper):
"""Mapper to split text samples to sentences."""
+ _batched_op = True
+
def __init__(self, lang: str = 'en', *args, **kwargs):
"""
Initialization method.
@@ -26,10 +28,14 @@ def __init__(self, lang: str = 'en', *args, **kwargs):
self.lang = lang
self.model_key = prepare_model(model_type='nltk', lang=lang)
- def process(self, sample):
+ def process(self, samples):
nltk_model = get_model(self.model_key)
- sample[self.text_key] = get_sentences_from_document(
- sample[self.text_key],
- model_func=nltk_model.tokenize if nltk_model else None)
- return sample
+
+ samples[self.text_key] = [
+ get_sentences_from_document(
+ text, model_func=nltk_model.tokenize if nltk_model else None)
+ for text in samples[self.text_key]
+ ]
+
+ return samples
diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py
index 6fa44b559..3102cedab 100644
--- a/data_juicer/ops/mapper/whitespace_normalization_mapper.py
+++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py
@@ -16,6 +16,8 @@ class WhitespaceNormalizationMapper(Mapper):
https://en.wikipedia.org/wiki/Whitespace_character
"""
+ _batched_op = True
+
def __init__(self, *args, **kwargs):
"""
Initialization method.
@@ -25,13 +27,15 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)
- def process(self, sample):
- # remove whitespaces before and after the main content
- text = sample[self.text_key].strip()
+ def process(self, samples):
+ for idx, text in enumerate(samples[self.text_key]):
+ # remove whitespaces before and after the main content
+ text = text.strip()
- # replace all kinds of whitespaces with ' '
- sample[self.text_key] = ''.join([
- char if char not in VARIOUS_WHITESPACES else ' ' for char in text
- ])
+ # replace all kinds of whitespaces with ' '
+ samples[self.text_key][idx] = ''.join([
+ char if char not in VARIOUS_WHITESPACES else ' '
+ for char in text
+ ])
- return sample
+ return samples
diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py
index 54d9a44dc..e7e568f48 100644
--- a/tests/config/test_config_funcs.py
+++ b/tests/config/test_config_funcs.py
@@ -50,6 +50,7 @@ def test_yaml_cfg_file(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
@@ -67,6 +68,7 @@ def test_yaml_cfg_file(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
}, 'nested dict load fail, un-expected internal value')
@@ -132,6 +134,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
})
self.assertDictEqual(
@@ -149,6 +152,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
})
self.assertDictEqual(
@@ -166,6 +170,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
})
self.assertDictEqual(
@@ -183,6 +188,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
})
self.assertDictEqual(
@@ -200,6 +206,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
+ 'batch_size': 1,
}
})
diff --git a/tests/ops/filter/test_alphanumeric_filter.py b/tests/ops/filter/test_alphanumeric_filter.py
index d4ea828c0..3d66189d0 100644
--- a/tests/ops/filter/test_alphanumeric_filter.py
+++ b/tests/ops/filter/test_alphanumeric_filter.py
@@ -1,9 +1,6 @@
import unittest
-from data_juicer.core.data import NestedDataset as Dataset
-
from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter
-from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG
@@ -39,7 +36,7 @@ def test_case(self):
'text': 'emoji表情测试下😊,😸31231\n'
}]
dataset = self.generate_dataset(ds_list)
- op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9)
+ op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9,
+ batch_size=3, num_proc=1)
result = self.run_single_op(dataset, op, ["text"])
self.assertDatasetEqual(result, tgt_list)
@@ -67,7 +64,7 @@ def test_token_case(self):
'text': 'Do you need a cup of coffee?'
}]
dataset = self.generate_dataset(ds_list)
- op = AlphanumericFilter(tokenization=True, min_ratio=1.5)
+ op = AlphanumericFilter(tokenization=True, min_ratio=1.5,
+ batch_size=2, num_proc=1)
result = self.run_single_op(dataset, op, ["text"])
self.assertDatasetEqual(result, tgt_list)
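
Passing `batch_size` and `num_proc` through the op constructor routes these tests down the batched code path end to end. In a batched filter, `process` must return one keep/drop decision per sample in the batch rather than a single boolean; a hedged sketch of that shape, where the `alnum_ratio` stats key is assumed for illustration:

from data_juicer.utils.constant import Fields


def process_sketch(samples):
    # One boolean per row: True keeps the sample, False drops it.
    return [
        0.2 <= stats['alnum_ratio'] <= 0.9    # 'alnum_ratio' is assumed
        for stats in samples[Fields.stats]
    ]


fake_batch = {Fields.stats: [{'alnum_ratio': 0.5}, {'alnum_ratio': 0.95}]}
assert process_sketch(fake_batch) == [True, False]
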
diff --git a/tests/ops/filter/test_average_line_length_filter.py b/tests/ops/filter/test_average_line_length_filter.py
index e294cb77e..e9768d605 100644
--- a/tests/ops/filter/test_average_line_length_filter.py
+++ b/tests/ops/filter/test_average_line_length_filter.py
@@ -4,50 +4,77 @@
from data_juicer.ops.filter.average_line_length_filter import \
AverageLineLengthFilter
-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import Fields, InterVars
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
class AverageLineLengthFilterTest(DataJuicerTestCaseBase):
+ text_key = 'text'
+ ds_list = [{
+ text_key: 'a=1\nb\nc=1+2+3+5\nd=6'
+ }, {
+ text_key:
+ "Today is Sund Sund Sunda and it's a happy day!\nYou know"
+ }, {
+ text_key: 'a v s e e f g a qkc'
+ }, {
+ text_key: ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
+ }, {
+ text_key: 'Do you need a cup of coffee?'
+ }, {
+ text_key: 'emoji表情测试下😊,😸31231\n'
+ }]
+ tgt_list = [{
+ text_key: 'a v s e e f g a qkc'
+ }, {
+ text_key: 'emoji表情测试下😊,😸31231\n'
+ }]
def _run_average_line_length_filter(self, dataset: Dataset, target_list,
- op):
+ op, context=False):
if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
- dataset = dataset.select_columns(column_names=['text'])
- res_list = dataset.to_list()
+ if context:
+ dataset = dataset.add_column(name=Fields.context,
+ column=[{}] * dataset.num_rows)
+ dataset = dataset.map(
+ op.compute_stats,
+ batch_size=op.batch_size,
+ fn_kwargs={'context': context}
+ )
+ dataset = dataset.filter(op.process, batch_size=op.batch_size)
+ dataset_test = dataset.select_columns(column_names=[self.text_key])
+ res_list = dataset_test.to_list()
self.assertEqual(res_list, target_list)
- def test_case(self):
+ return dataset
- ds_list = [{
- 'text': 'a=1\nb\nc=1+2+3+5\nd=6'
- }, {
- 'text':
- "Today is Sund Sund Sunda and it's a happy day!\nYou know"
- }, {
- 'text': 'a v s e e f g a qkc'
- }, {
- 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
- }, {
- 'text': 'Do you need a cup of coffee?'
- }, {
- 'text': 'emoji表情测试下😊,😸31231\n'
- }]
- tgt_list = [{
- 'text': 'a v s e e f g a qkc'
- }, {
- 'text': 'emoji表情测试下😊,😸31231\n'
- }]
- dataset = Dataset.from_list(ds_list)
- op = AverageLineLengthFilter(min_len=10, max_len=20)
- self._run_average_line_length_filter(dataset, tgt_list, op)
+ def test_case_default(self):
+ dataset = Dataset.from_list(self.ds_list)
+ op = AverageLineLengthFilter(min_len=10, max_len=20, batch_size=3)
+ self._run_average_line_length_filter(
+ dataset, self.tgt_list, op, context=False)
+
+ def test_case_context(self):
+ dataset = Dataset.from_list(self.ds_list)
+ op = AverageLineLengthFilter(min_len=10, max_len=20, batch_size=2)
+ dataset = self._run_average_line_length_filter(
+ dataset, self.tgt_list, op, context=True)
+
+ dataset = dataset.select_columns(column_names=[Fields.context])
+ res_list = dataset.to_list()
+
+ tgt_context_list = [
+ {
+ Fields.context: {
+ InterVars.lines: tgt[self.text_key].splitlines()
+ }
+ } for tgt in self.tgt_list
+ ]
+
+ self.assertEqual(res_list, tgt_context_list)
if __name__ == '__main__':
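
The new `test_case_context` exercises the context cache: with `context=True` a `Fields.context` column is added up front, and after the run each surviving row is expected to hold its split lines under `InterVars.lines`. A hedged sketch of the caching step inside a batched `compute_stats`, with the signature inferred from the `fn_kwargs` above:

from data_juicer.utils.constant import Fields, InterVars


def compute_stats_sketch(samples, context=False):
    for idx, text in enumerate(samples['text']):
        lines = text.splitlines()
        if context:
            # Cache the intermediate value so downstream ops can reuse it.
            samples[Fields.context][idx][InterVars.lines] = lines
        # ... per-sample stats derived from `lines` would be written to
        # samples[Fields.stats][idx] here.
    return samples


batch = {'text': ['a=1\nb'], Fields.context: [{}]}
compute_stats_sketch(batch, context=True)
assert batch[Fields.context][0][InterVars.lines] == ['a=1', 'b']
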
diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py
index 77c1ac1d2..8c4334efd 100644
--- a/tests/ops/filter/test_character_repetition_filter.py
+++ b/tests/ops/filter/test_character_repetition_filter.py
@@ -18,8 +18,8 @@ def _run_character_repetition_filter(self, dataset: Dataset, target_list,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
+ dataset = dataset.map(op.compute_stats,
+ batch_size=op.batch_size,
+ num_proc=1)
+ dataset = dataset.filter(op.process,
+ batch_size=op.batch_size,
+ num_proc=2)
dataset = dataset.select_columns(column_names=['text'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
@@ -42,7 +42,11 @@ def test_case(self):
'text': '中文也是一个字算一个长度'
}]
dataset = Dataset.from_list(ds_list)
- op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=0.4)
+ op = CharacterRepetitionFilter(
+ rep_len=5,
+ min_ratio=0.0,
+ max_ratio=0.4,
+ batch_size=2)
self._run_character_repetition_filter(dataset, tgt_list, op)
diff --git a/tests/ops/filter/test_maximum_line_length_filter.py b/tests/ops/filter/test_maximum_line_length_filter.py
index 6f1cab7f6..6596c0f34 100644
--- a/tests/ops/filter/test_maximum_line_length_filter.py
+++ b/tests/ops/filter/test_maximum_line_length_filter.py
@@ -4,50 +4,77 @@
from data_juicer.ops.filter.maximum_line_length_filter import \
MaximumLineLengthFilter
-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import Fields, InterVars
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
class MaximumLineLengthFilterTest(DataJuicerTestCaseBase):
+ text_key = 'text'
+ ds_list = [{
+ text_key: 'a=1\nb\nc=1+2+3+5\nd=6'
+ }, {
+ text_key:
+ "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know"
+ }, {
+ text_key: 'a v s e e f g a qkc'
+ }, {
+ text_key: ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
+ }, {
+ text_key: 'Do you need a cup of coffee?'
+ }, {
+ text_key: 'emoji表情测试下😊,😸31231\n'
+ }]
+ tgt_list = [{
+ text_key: 'a v s e e f g a qkc'
+ }, {
+ text_key: 'emoji表情测试下😊,😸31231\n'
+ }]
def _run_maximum_line_length_filter(self, dataset: Dataset, target_list,
- op):
+ op, context=False):
if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
- dataset = dataset.select_columns(column_names=['text'])
- res_list = dataset.to_list()
+ if context:
+ dataset = dataset.add_column(name=Fields.context,
+ column=[{}] * dataset.num_rows)
+ dataset = dataset.map(
+ op.compute_stats,
+ batch_size=op.batch_size,
+ fn_kwargs={'context': context}
+ )
+ dataset = dataset.filter(op.process, batch_size=op.batch_size)
+ dataset_test = dataset.select_columns(column_names=[self.text_key])
+ res_list = dataset_test.to_list()
self.assertEqual(res_list, target_list)
- def test_case(self):
+ return dataset
- ds_list = [{
- 'text': 'a=1\nb\nc=1+2+3+5\nd=6'
- }, {
- 'text':
- "Today is Sund Sund Sund Sunda and it's a happy day!\nYou know"
- }, {
- 'text': 'a v s e e f g a qkc'
- }, {
- 'text': ',。、„”“«»1」「《》´∶:?!();–—.~’…━〈〉【】%►'
- }, {
- 'text': 'Do you need a cup of coffee?'
- }, {
- 'text': 'emoji表情测试下😊,😸31231\n'
- }]
- tgt_list = [{
- 'text': 'a v s e e f g a qkc'
- }, {
- 'text': 'emoji表情测试下😊,😸31231\n'
- }]
- dataset = Dataset.from_list(ds_list)
- op = MaximumLineLengthFilter(min_len=10, max_len=20)
- self._run_maximum_line_length_filter(dataset, tgt_list, op)
+ def test_case_default(self):
+ dataset = Dataset.from_list(self.ds_list)
+ op = MaximumLineLengthFilter(min_len=10, max_len=20, batch_size=3)
+ self._run_maximum_line_length_filter(
+ dataset, self.tgt_list, op, context=False)
+
+ def test_case_context(self):
+ dataset = Dataset.from_list(self.ds_list)
+ op = MaximumLineLengthFilter(min_len=10, max_len=20, batch_size=2)
+ dataset = self._run_maximum_line_length_filter(
+ dataset, self.tgt_list, op, context=True)
+
+ dataset = dataset.select_columns(column_names=[Fields.context])
+ res_list = dataset.to_list()
+
+ tgt_context_list = [
+ {
+ Fields.context: {
+ InterVars.lines: tgt[self.text_key].splitlines()
+ }
+ } for tgt in self.tgt_list
+ ]
+
+ self.assertEqual(res_list, tgt_context_list)
if __name__ == '__main__':
diff --git a/tests/ops/filter/test_perplexity_filter.py b/tests/ops/filter/test_perplexity_filter.py
index 07e87d17c..e3e582ff6 100644
--- a/tests/ops/filter/test_perplexity_filter.py
+++ b/tests/ops/filter/test_perplexity_filter.py
@@ -9,20 +9,28 @@
class PerplexityFilterTest(DataJuicerTestCaseBase):
- def _run_perplexity_filter(self, dataset: Dataset, target_list, op):
+ def _run_perplexity_filter(self, dataset: Dataset, target_list, op,
+ context=False):
if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
- dataset = dataset.select_columns(column_names=['text'])
- res_list = dataset.to_list()
+ if context:
+ dataset = dataset.add_column(name=Fields.context,
+ column=[{}] * dataset.num_rows)
+ dataset = dataset.map(
+ op.compute_stats,
+ batch_size=op.batch_size,
+ fn_kwargs={'context': context}
+ )
+ dataset = dataset.filter(op.process, batch_size=op.batch_size)
+ dataset_test = dataset.select_columns(column_names=['text'])
+ res_list = dataset_test.to_list()
self.assertEqual(res_list, target_list)
+ return dataset
- def test_en_case(self):
+ def _test_en_case(self, context=False):
ds_list = [{
'text': "Today is Sunday and it's a happy day!"
@@ -44,8 +52,24 @@ def test_en_case(self):
'text': 'Do you need a cup of coffee?'
}]
dataset = Dataset.from_list(ds_list)
- op = PerplexityFilter(lang='en', max_ppl=900)
- self._run_perplexity_filter(dataset, tgt_list, op)
+ op = PerplexityFilter(lang='en', max_ppl=900, batch_size=2)
+ dataset = self._run_perplexity_filter(dataset, tgt_list, op, context)
+ if context:
+ dataset = dataset.select_columns(column_names=[Fields.context])
+ context_list = dataset.to_list()
+ res_words_list = [list(context_list[i][Fields.context].values())
+ for i in range(len(context_list))]
+ tgt_words_list = [
+ [['▁Today', '▁is', '▁Sunday', '▁and', '▁it', "'", 's', '▁a',
+ '▁happy', '▁day', '!']],
+ [['▁Do', '▁you', '▁need', '▁a', '▁cup', '▁of', '▁coffee', '?']]
+ ]
+ self.assertListEqual(res_words_list, tgt_words_list)
+
+ def test_en_case_default(self):
+ self._test_en_case(context=False)
+
+ def test_en_case_context(self):
+ self._test_en_case(context=True)
if __name__ == '__main__':
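
Here the cached intermediate is the tokenized word list: sentencepiece-style tokens in which `▁` marks a word boundary. Note the test does not assume the cache's key name; it collects all values per row with `.values()`. A minimal illustration of that extraction step, with a hypothetical key name:

row_context = {'words': ['▁Do', '▁you', '▁need']}    # key name is hypothetical
assert list(row_context.values()) == [['▁Do', '▁you', '▁need']]
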
diff --git a/tests/ops/filter/test_special_characters_filter.py b/tests/ops/filter/test_special_characters_filter.py
index b1dd8632e..b2e8292f1 100644
--- a/tests/ops/filter/test_special_characters_filter.py
+++ b/tests/ops/filter/test_special_characters_filter.py
@@ -18,8 +18,8 @@ def _run_special_characters_filter(self, dataset: Dataset, target_list,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
+ dataset = dataset.map(op.compute_stats, batch_size=op.batch_size)
+ dataset = dataset.filter(op.process, batch_size=op.batch_size)
dataset = dataset.select_columns(column_names=['text'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
@@ -49,7 +49,7 @@ def test_case(self):
'text': 'Do you need a cup of coffee?'
}]
dataset = Dataset.from_list(ds_list)
- op = SpecialCharactersFilter(min_ratio=0.0, max_ratio=0.25)
+ op = SpecialCharactersFilter(min_ratio=0.0, max_ratio=0.25, batch_size=2)
self._run_special_characters_filter(dataset, tgt_list, op)
diff --git a/tests/ops/filter/test_text_length_filter.py b/tests/ops/filter/test_text_length_filter.py
index 67efb6c60..2a653ff42 100644
--- a/tests/ops/filter/test_text_length_filter.py
+++ b/tests/ops/filter/test_text_length_filter.py
@@ -16,8 +16,8 @@ def _run_text_length_filter(self, dataset: Dataset, target_list, op):
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
+ dataset = dataset.map(op.compute_stats, batch_size=3)
+ dataset = dataset.filter(op.process, batch_size=2)
dataset = dataset.select_columns(column_names=['text'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py
index 0d53a164d..f099bec05 100644
--- a/tests/ops/filter/test_word_num_filter.py
+++ b/tests/ops/filter/test_word_num_filter.py
@@ -16,8 +16,8 @@ def _run_words_num_filter(self, dataset: Dataset, target_list, op):
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
+ dataset = dataset.map(op.compute_stats, batch_size=op.batch_size)
+ dataset = dataset.filter(op.process, batch_size=op.batch_size)
dataset = dataset.select_columns(column_names=['text'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
@@ -41,7 +41,7 @@ def test_case(self):
'text': 'a v s e c s f e f g a a a '
}]
dataset = Dataset.from_list(ds_list)
- op = WordsNumFilter(min_num=5, max_num=15)
+ op = WordsNumFilter(min_num=5, max_num=15, batch_size=2)
self._run_words_num_filter(dataset, tgt_list, op)
def test_zh_case(self):
@@ -68,7 +68,8 @@ def test_zh_case(self):
op = WordsNumFilter(lang='zh',
tokenization=True,
min_num=10,
- max_num=25)
+ max_num=25,
+ batch_size=1)
self._run_words_num_filter(dataset, tgt_list, op)
diff --git a/tests/ops/filter/test_word_repetition_filter.py b/tests/ops/filter/test_word_repetition_filter.py
index f59576ef8..030189d56 100644
--- a/tests/ops/filter/test_word_repetition_filter.py
+++ b/tests/ops/filter/test_word_repetition_filter.py
@@ -16,8 +16,8 @@ def _run_word_repetition_filter(self, dataset: Dataset, target_list, op):
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)
- dataset = dataset.map(op.compute_stats)
- dataset = dataset.filter(op.process)
+ dataset = dataset.map(op.compute_stats, batch_size=op.batch_size)
+ dataset = dataset.filter(op.process, batch_size=op.batch_size)
dataset = dataset.select_columns(column_names=['text'])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
@@ -51,7 +51,11 @@ def test_en_case(self):
'This proposed a novel proposed pretraining proposed pretraining.'
}]
dataset = Dataset.from_list(ds_list)
- op = WordRepetitionFilter(rep_len=3, min_ratio=0.0, max_ratio=0.2)
+ op = WordRepetitionFilter(
+ rep_len=3,
+ min_ratio=0.0,
+ max_ratio=0.2,
+ batch_size=2)
self._run_word_repetition_filter(dataset, tgt_list, op)
def test_zh_case(self):
@@ -79,7 +83,8 @@ def test_zh_case(self):
tokenization=True,
rep_len=3,
min_ratio=0.0,
- max_ratio=0.2)
+ max_ratio=0.2,
+ batch_size=1)
self._run_word_repetition_filter(dataset, tgt_list, op)
diff --git a/tests/ops/mapper/test_chinese_convert_mapper.py b/tests/ops/mapper/test_chinese_convert_mapper.py
index 9bbe8e8df..bc21f40fe 100644
--- a/tests/ops/mapper/test_chinese_convert_mapper.py
+++ b/tests/ops/mapper/test_chinese_convert_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.chinese_convert_mapper import ChineseConvertMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self, mode='s2t'):
self.op = ChineseConvertMapper(mode)
def _run_chinese_convert(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_s2t(self):
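
This refactor, repeated across all of the mapper tests below, swaps per-sample calls to `op.process(sample)` for a real `NestedDataset.map` with an explicit `batch_size`, so the batched path is what the tests actually exercise. The pattern as a self-contained hedged sketch; the sample pair is illustrative and assumes the opencc backend used by ChineseConvertMapper is available:

from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.chinese_convert_mapper import ChineseConvertMapper

op = ChineseConvertMapper('s2t')    # simplified -> traditional Chinese
samples = [{'text': '这是一个简单的测试', 'target': '這是一個簡單的測試'}]

dataset = Dataset.from_list(samples)
dataset = dataset.map(op.process, batch_size=2)    # batched code path

for data in dataset:
    assert data['text'] == data['target']
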
diff --git a/tests/ops/mapper/test_clean_copyright_mapper.py b/tests/ops/mapper/test_clean_copyright_mapper.py
index 726d829f7..a236988f7 100644
--- a/tests/ops/mapper/test_clean_copyright_mapper.py
+++ b/tests/ops/mapper/test_clean_copyright_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.clean_copyright_mapper import CleanCopyrightMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self):
self.op = CleanCopyrightMapper()
def _run_clean_copyright(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_clean_copyright(self):
diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py
index b3f0e5e9a..1ff7e389e 100644
--- a/tests/ops/mapper/test_clean_email_mapper.py
+++ b/tests/ops/mapper/test_clean_email_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -7,9 +8,11 @@
class CleanEmailMapperTest(DataJuicerTestCaseBase):
def _run_clean_email(self, op, samples):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_clean_email(self):
diff --git a/tests/ops/mapper/test_clean_html_mapper.py b/tests/ops/mapper/test_clean_html_mapper.py
index 69249b60a..71d4e11ee 100644
--- a/tests/ops/mapper/test_clean_html_mapper.py
+++ b/tests/ops/mapper/test_clean_html_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.clean_html_mapper import CleanHtmlMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self):
self.op = CleanHtmlMapper()
def _run_helper(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_complete_html_text(self):
diff --git a/tests/ops/mapper/test_clean_ip_mapper.py b/tests/ops/mapper/test_clean_ip_mapper.py
index ccbaf52b7..479228263 100644
--- a/tests/ops/mapper/test_clean_ip_mapper.py
+++ b/tests/ops/mapper/test_clean_ip_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.clean_ip_mapper import CleanIpMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -7,9 +8,11 @@
class CleanIpMapperTest(DataJuicerTestCaseBase):
def _run_clean_ip(self, op, samples):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_ipv4(self):
diff --git a/tests/ops/mapper/test_clean_links_mapper.py b/tests/ops/mapper/test_clean_links_mapper.py
index 28e14b2d9..5efcd4acd 100644
--- a/tests/ops/mapper/test_clean_links_mapper.py
+++ b/tests/ops/mapper/test_clean_links_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.clean_links_mapper import CleanLinksMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self):
self.op = CleanLinksMapper()
def _run_clean_links(self, op, samples):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_lower_ftp_links_text(self):
diff --git a/tests/ops/mapper/test_expand_macro_mapper.py b/tests/ops/mapper/test_expand_macro_mapper.py
index 68dbf047b..bdc758193 100644
--- a/tests/ops/mapper/test_expand_macro_mapper.py
+++ b/tests/ops/mapper/test_expand_macro_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.expand_macro_mapper import ExpandMacroMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self):
self.op = ExpandMacroMapper()
def _run_expand_macro(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_case(self):
diff --git a/tests/ops/mapper/test_fix_unicode_mapper.py b/tests/ops/mapper/test_fix_unicode_mapper.py
index 547020b51..1e969d117 100644
--- a/tests/ops/mapper/test_fix_unicode_mapper.py
+++ b/tests/ops/mapper/test_fix_unicode_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.fix_unicode_mapper import FixUnicodeMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self):
self.op = FixUnicodeMapper()
def _run_fix_unicode(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_bad_unicode_text(self):
diff --git a/tests/ops/mapper/test_punctuation_normalization_mapper.py b/tests/ops/mapper/test_punctuation_normalization_mapper.py
index a69d4040e..080666ce8 100644
--- a/tests/ops/mapper/test_punctuation_normalization_mapper.py
+++ b/tests/ops/mapper/test_punctuation_normalization_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.punctuation_normalization_mapper import \
PunctuationNormalizationMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -11,9 +12,11 @@ def setUp(self):
self.op = PunctuationNormalizationMapper()
def _run_punctuation_normalization(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_case(self):
diff --git a/tests/ops/mapper/test_remove_bibliography_mapper.py b/tests/ops/mapper/test_remove_bibliography_mapper.py
index 76096fe93..9d08c2a4d 100644
--- a/tests/ops/mapper/test_remove_bibliography_mapper.py
+++ b/tests/ops/mapper/test_remove_bibliography_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_bibliography_mapper import \
RemoveBibliographyMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -11,9 +12,11 @@ def setUp(self):
self.op = RemoveBibliographyMapper()
def _run_remove_bibliography(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_bibliography_case(self):
diff --git a/tests/ops/mapper/test_remove_comments_mapper.py b/tests/ops/mapper/test_remove_comments_mapper.py
index 81a0df5de..93a287460 100644
--- a/tests/ops/mapper/test_remove_comments_mapper.py
+++ b/tests/ops/mapper/test_remove_comments_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_comments_mapper import RemoveCommentsMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -7,9 +8,11 @@
class RemoveCommentsMapperTest(DataJuicerTestCaseBase):
def _run_remove_comments(self, samples, op):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_tex_case(self):
diff --git a/tests/ops/mapper/test_remove_header_mapper.py b/tests/ops/mapper/test_remove_header_mapper.py
index c91bfe790..0196b0317 100644
--- a/tests/ops/mapper/test_remove_header_mapper.py
+++ b/tests/ops/mapper/test_remove_header_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_header_mapper import RemoveHeaderMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@ def setUp(self):
self.op = RemoveHeaderMapper()
def _run_remove_header(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_case(self):
diff --git a/tests/ops/mapper/test_remove_long_words_mapper.py b/tests/ops/mapper/test_remove_long_words_mapper.py
index 533d7a717..817043979 100644
--- a/tests/ops/mapper/test_remove_long_words_mapper.py
+++ b/tests/ops/mapper/test_remove_long_words_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_long_words_mapper import \
RemoveLongWordsMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -8,9 +9,11 @@
class RemoveLongWordsMapperTest(DataJuicerTestCaseBase):
def _run_remove_long_words(self, samples, op):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_normal_case(self):
diff --git a/tests/ops/mapper/test_remove_non_chinese_character_mapper.py b/tests/ops/mapper/test_remove_non_chinese_character_mapper.py
index 283a75ab0..1ab000107 100644
--- a/tests/ops/mapper/test_remove_non_chinese_character_mapper.py
+++ b/tests/ops/mapper/test_remove_non_chinese_character_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_non_chinese_character_mapper import \
RemoveNonChineseCharacterlMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -12,9 +13,11 @@ def setUp(self, keep_alphabet=True, keep_number=True, keep_punc=True):
keep_punc)
def _run_remove_non_chinese_character(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_remove_non_chinese_character(self):
diff --git a/tests/ops/mapper/test_remove_repeat_sentences_mapper.py b/tests/ops/mapper/test_remove_repeat_sentences_mapper.py
index a7fe347fe..a2a560f97 100644
--- a/tests/ops/mapper/test_remove_repeat_sentences_mapper.py
+++ b/tests/ops/mapper/test_remove_repeat_sentences_mapper.py
@@ -2,6 +2,7 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_repeat_sentences_mapper import \
RemoveRepeatSentencesMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -10,9 +11,11 @@
class RemoveRepeatSentencesMapperTest(DataJuicerTestCaseBase):
def _run_helper(self, samples, op):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_text(self):
diff --git a/tests/ops/mapper/test_remove_specific_chars_mapper.py b/tests/ops/mapper/test_remove_specific_chars_mapper.py
index f61a3f6fc..f786db3ee 100644
--- a/tests/ops/mapper/test_remove_specific_chars_mapper.py
+++ b/tests/ops/mapper/test_remove_specific_chars_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_specific_chars_mapper import \
RemoveSpecificCharsMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -11,9 +12,11 @@ def setUp(self):
self.op = RemoveSpecificCharsMapper()
def _run_helper(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_complete_html_text(self):
diff --git a/tests/ops/mapper/test_remove_table_text_mapper.py b/tests/ops/mapper/test_remove_table_text_mapper.py
index 2be4a2453..214a36e79 100644
--- a/tests/ops/mapper/test_remove_table_text_mapper.py
+++ b/tests/ops/mapper/test_remove_table_text_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_table_text_mapper import \
RemoveTableTextMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -11,9 +12,11 @@ def setUp(self):
self.op = RemoveTableTextMapper()
def _run_remove_header(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_single_table_case(self):
diff --git a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py
index 02157ad52..f6d6d109f 100644
--- a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py
+++ b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.remove_words_with_incorrect_substrings_mapper import \
RemoveWordsWithIncorrectSubstringsMapper # noqa: E501
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -8,9 +9,11 @@
class RemoveWordsWithIncorrectSubstringsMapperTest(DataJuicerTestCaseBase):
def _run_remove_words_with_incorrect_sbstrings(self, samples, op):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_en_case(self):
diff --git a/tests/ops/mapper/test_replace_content_mapper.py b/tests/ops/mapper/test_replace_content_mapper.py
index 64f88c888..23ddb3453 100644
--- a/tests/ops/mapper/test_replace_content_mapper.py
+++ b/tests/ops/mapper/test_replace_content_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.replace_content_mapper import ReplaceContentMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -7,9 +8,11 @@
class ReplaceContentMapperTest(DataJuicerTestCaseBase):
def _run_helper(self, op, samples):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_special_char_pattern_text(self):
diff --git a/tests/ops/mapper/test_sentence_split_mapper.py b/tests/ops/mapper/test_sentence_split_mapper.py
index 3cdf3a977..4352fdce8 100644
--- a/tests/ops/mapper/test_sentence_split_mapper.py
+++ b/tests/ops/mapper/test_sentence_split_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.sentence_split_mapper import SentenceSplitMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -7,9 +8,11 @@
class SentenceSplitMapperTest(DataJuicerTestCaseBase):
def _run_helper(self, op, samples):
- for sample in samples:
- result = op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_en_text(self):
diff --git a/tests/ops/mapper/test_whitespace_normalization_mapper.py b/tests/ops/mapper/test_whitespace_normalization_mapper.py
index 985cc7076..c92516ff7 100644
--- a/tests/ops/mapper/test_whitespace_normalization_mapper.py
+++ b/tests/ops/mapper/test_whitespace_normalization_mapper.py
@@ -1,5 +1,6 @@
import unittest
+from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.whitespace_normalization_mapper import \
WhitespaceNormalizationMapper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
@@ -11,9 +12,11 @@ def setUp(self):
self.op = WhitespaceNormalizationMapper()
def _run_whitespace_normalization(self, samples):
- for sample in samples:
- result = self.op.process(sample)
- self.assertEqual(result['text'], result['target'])
+ dataset = Dataset.from_list(samples)
+ dataset = dataset.map(self.op.process, batch_size=2)
+
+ for data in dataset:
+ self.assertEqual(data['text'], data['target'])
def test_case(self):