Skip to content

Commit

Permalink
refine done (#435)
Browse files — browse the repository at this point in the history
  • Loading branch information
BeachWang authored Sep 24, 2024
1 parent 18f2248 commit 467cb96
Show file tree
Hide file tree
Showing 18 changed files with 75 additions and 78 deletions.
12 changes: 11 additions & 1 deletion data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get('image_key', 'images')
self.audio_key = kwargs.get('audio_key', 'audios')
self.video_key = kwargs.get('video_key', 'videos')
self.batch_size = kwargs.get('batch_size', 1)
self.batch_size = kwargs.get('batch_size', 1000)

# whether the model can be accelerated using cuda
_accelerator = kwargs.get('accelerator', None)
Expand Down Expand Up @@ -204,6 +204,12 @@ def add_parameters(self, init_parameter_dict, **extra_param_dict):
related_parameters.update(extra_param_dict)
return related_parameters

def run(self, dataset):
from data_juicer.core.data import NestedDataset
if not isinstance(dataset, NestedDataset):
dataset = NestedDataset(dataset)
return dataset


class Mapper(OP):

Expand Down Expand Up @@ -238,6 +244,7 @@ def process(self, sample):
raise NotImplementedError

def run(self, dataset, *, exporter=None, tracer=None):
dataset = super(Mapper, self).run(dataset)
new_dataset = dataset.map(
self.process,
num_proc=self.runtime_np(),
Expand Down Expand Up @@ -298,6 +305,7 @@ def process(self, sample):
raise NotImplementedError

def run(self, dataset, *, exporter=None, tracer=None):
dataset = super(Filter, self).run(dataset)
if Fields.stats not in dataset.features:
from data_juicer.core.data import add_same_content_to_new_column
dataset = dataset.map(add_same_content_to_new_column,
Expand Down Expand Up @@ -368,6 +376,7 @@ def process(self, dataset, show_num=0):
raise NotImplementedError

def run(self, dataset, *, exporter=None, tracer=None):
dataset = super(Deduplicator, self).run(dataset)
dataset = dataset.map(self.compute_hash,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
Expand Down Expand Up @@ -406,6 +415,7 @@ def process(self, dataset):
raise NotImplementedError

def run(self, dataset, *, exporter=None, tracer=None):
dataset = super(Selector, self).run(dataset)
new_dataset = self.process(dataset)
if tracer:
tracer.trace_filter(self._name, dataset, new_dataset)
Expand Down
7 changes: 3 additions & 4 deletions data_juicer/ops/filter/alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,9 @@ def process(self, samples):
ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \
else StatsKeys.alnum_ratio
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_ratio <= stat[ratio_key] <= self.
max_ratio, samples[Fields.stats]))
return map(
lambda stat: self.min_ratio <= stat[ratio_key] <= self.
max_ratio, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_ratio <= samples[
Expand Down
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/average_line_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,9 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_len <= stat[StatsKeys.avg_line_length
] <= self.max_len,
samples[Fields.stats]))
return map(
lambda stat: self.min_len <= stat[StatsKeys.avg_line_length] <=
self.max_len, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_len <= samples[Fields.stats][
Expand Down
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/character_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,9 @@ def compute_stats(self, samples):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_ratio <= stat[
StatsKeys.char_rep_ratio] <= self.max_ratio,
samples[Fields.stats]))
return map(
lambda stat: self.min_ratio <= stat[StatsKeys.char_rep_ratio]
<= self.max_ratio, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_ratio <= samples[Fields.stats][
Expand Down
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/maximum_line_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_len <= stat[StatsKeys.max_line_length
] <= self.max_len,
samples[Fields.stats]))
return map(
lambda stat: self.min_len <= stat[StatsKeys.max_line_length] <=
self.max_len, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_len <= samples[Fields.stats][
Expand Down
5 changes: 2 additions & 3 deletions data_juicer/ops/filter/perplexity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl,
samples[Fields.stats]))
return map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl,
samples[Fields.stats])
else:
return samples[Fields.stats][StatsKeys.perplexity] <= self.max_ppl
9 changes: 4 additions & 5 deletions data_juicer/ops/filter/special_characters_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ def compute_stats(self, samples):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_ratio <= stat[
StatsKeys.special_char_ratio] <= self.max_ratio,
samples[Fields.stats]))
return map(
lambda stat: self.min_ratio <= stat[
StatsKeys.special_char_ratio] <= self.max_ratio,
samples[Fields.stats])
else:
# single sample for ray filter
if self.min_ratio <= \
Expand Down
7 changes: 3 additions & 4 deletions data_juicer/ops/filter/text_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ def compute_stats(self, samples):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_len <= stat[StatsKeys.text_len] <=
self.max_len, samples[Fields.stats]))
return map(
lambda stat: self.min_len <= stat[StatsKeys.text_len] <= self.
max_len, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_len <= samples[Fields.stats][
Expand Down
8 changes: 3 additions & 5 deletions data_juicer/ops/filter/word_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,9 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_ratio <= stat[
StatsKeys.word_rep_ratio] <= self.max_ratio,
samples[Fields.stats]))
return map(
lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio]
<= self.max_ratio, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_ratio <= samples[Fields.stats][
Expand Down
7 changes: 3 additions & 4 deletions data_juicer/ops/filter/words_num_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return list(
map(
lambda stat: self.min_num <= stat[StatsKeys.num_words] <=
self.max_num, samples[Fields.stats]))
return map(
lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self.
max_num, samples[Fields.stats])
else:
# single sample for ray filter
if self.min_num <= samples[Fields.stats][
Expand Down
6 changes: 3 additions & 3 deletions data_juicer/ops/mapper/chinese_convert_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, mode: str = 's2t', *args, **kwargs):
def process(self, samples):
prepare_converter(self.mode)

samples[self.text_key] = list(
map(lambda text: OPENCC_CONVERTER.convert(text),
samples[self.text_key]))
samples[self.text_key] = [
OPENCC_CONVERTER.convert(text) for text in samples[self.text_key]
]
return samples
7 changes: 4 additions & 3 deletions data_juicer/ops/mapper/clean_copyright_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def _process_single_sample(self, sample):
return sample

def process(self, samples):
samples[self.text_key] = list(
map(lambda text: self._process_single_sample(text),
samples[self.text_key]))
samples[self.text_key] = [
self._process_single_sample(text)
for text in samples[self.text_key]
]
return samples
5 changes: 3 additions & 2 deletions data_juicer/ops/mapper/clean_html_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _clean_html(raw_html):
parser = HTMLParser(raw_html)
return parser.text()

samples[self.text_key] = list(
map(lambda text: _clean_html(text), samples[self.text_key]))
samples[self.text_key] = [
_clean_html(text) for text in samples[self.text_key]
]
return samples
9 changes: 4 additions & 5 deletions data_juicer/ops/mapper/fix_unicode_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ def __init__(self, normalization: str = None, *args, **kwargs):
'["NFC", "NFKC", "NFD", "NFKD"]')

def process(self, samples):
samples[self.text_key] = list(
map(
lambda text: ftfy.fix_text(text,
normalization=self.normalization),
samples[self.text_key]))
samples[self.text_key] = [
ftfy.fix_text(text, normalization=self.normalization)
for text in samples[self.text_key]
]
return samples
9 changes: 4 additions & 5 deletions data_juicer/ops/mapper/punctuation_normalization_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ def __init__(self, *args, **kwargs):
}

def process(self, samples):
samples[self.text_key] = list(
map(
lambda text: ''.join(
[self.punctuation_unicode.get(c, c) for c in text]),
samples[self.text_key]))
samples[self.text_key] = [
''.join([self.punctuation_unicode.get(c, c) for c in text])
for text in samples[self.text_key]
]
return samples
12 changes: 6 additions & 6 deletions data_juicer/ops/mapper/remove_bibliography_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ def __init__(self, *args, **kwargs):
self.pattern += r').*$'

def process(self, samples):
samples[self.text_key] = list(
map(
lambda text: re.sub(pattern=self.pattern,
repl=r'',
string=text,
flags=re.DOTALL), samples[self.text_key]))
samples[self.text_key] = [
re.sub(pattern=self.pattern,
repl=r'',
string=text,
flags=re.DOTALL) for text in samples[self.text_key]
]

return samples
12 changes: 6 additions & 6 deletions data_juicer/ops/mapper/remove_specific_chars_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def process(self, samples):
if self.pattern is None:
return samples

samples[self.text_key] = list(
map(
lambda text: re.sub(pattern=self.pattern,
repl=r'',
string=text,
flags=re.DOTALL), samples[self.text_key]))
samples[self.text_key] = [
re.sub(pattern=self.pattern,
repl=r'',
string=text,
flags=re.DOTALL) for text in samples[self.text_key]
]
return samples
14 changes: 7 additions & 7 deletions tests/config/test_config_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_yaml_cfg_file(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
Expand All @@ -68,7 +68,7 @@ def test_yaml_cfg_file(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
}, 'nested dict load fail, un-expected internal value')

Expand Down Expand Up @@ -134,7 +134,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
})
self.assertDictEqual(
Expand All @@ -152,7 +152,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
})
self.assertDictEqual(
Expand All @@ -170,7 +170,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
})
self.assertDictEqual(
Expand All @@ -188,7 +188,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
})
self.assertDictEqual(
Expand All @@ -206,7 +206,7 @@ def test_mixture_cfg(self):
'cpu_required': 1,
'mem_required': 0,
'turbo': False,
'batch_size': 1,
'batch_size': 1000,
}
})

Expand Down

0 comments on commit 467cb96

Please sign in to comment.