Skip to content

Commit

Permalink
turn map to list
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Sep 19, 2024
1 parent 58f1851 commit 6072c18
Show file tree
Hide file tree
Showing 15 changed files with 66 additions and 47 deletions.
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ def process(self, samples):
ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \
else StatsKeys.alnum_ratio
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_ratio <= stat[ratio_key] <= self.
max_ratio, samples[Fields.stats])
return list(
map(
lambda stat: self.min_ratio <= stat[ratio_key] <= self.
max_ratio, samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_ratio <= samples[
Expand Down
8 changes: 5 additions & 3 deletions data_juicer/ops/filter/average_line_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_len <= stat[StatsKeys.avg_line_length] <=
self.max_len, samples[Fields.stats])
return list(
map(
lambda stat: self.min_len <= stat[StatsKeys.avg_line_length
] <= self.max_len,
samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_len <= samples[Fields.stats][
Expand Down
8 changes: 5 additions & 3 deletions data_juicer/ops/filter/character_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def compute_stats(self, samples):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_ratio <= stat[StatsKeys.char_rep_ratio]
<= self.max_ratio, samples[Fields.stats])
return list(
map(
lambda stat: self.min_ratio <= stat[
StatsKeys.char_rep_ratio] <= self.max_ratio,
samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_ratio <= samples[Fields.stats][
Expand Down
8 changes: 5 additions & 3 deletions data_juicer/ops/filter/maximum_line_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_len <= stat[StatsKeys.max_line_length] <=
self.max_len, samples[Fields.stats])
return list(
map(
lambda stat: self.min_len <= stat[StatsKeys.max_line_length
] <= self.max_len,
samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_len <= samples[Fields.stats][
Expand Down
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/perplexity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl,
samples[Fields.stats])
return list(
map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl,
samples[Fields.stats]))
else:
return samples[Fields.stats][StatsKeys.perplexity] <= self.max_ppl
9 changes: 5 additions & 4 deletions data_juicer/ops/filter/special_characters_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ def compute_stats(self, samples):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_ratio <= stat[
StatsKeys.special_char_ratio] <= self.max_ratio,
samples[Fields.stats])
return list(
map(
lambda stat: self.min_ratio <= stat[
StatsKeys.special_char_ratio] <= self.max_ratio,
samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_ratio <= \
Expand Down
8 changes: 5 additions & 3 deletions data_juicer/ops/filter/word_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio]
<= self.max_ratio, samples[Fields.stats])
return list(
map(
lambda stat: self.min_ratio <= stat[
StatsKeys.word_rep_ratio] <= self.max_ratio,
samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_ratio <= samples[Fields.stats][
Expand Down
7 changes: 4 additions & 3 deletions data_juicer/ops/filter/words_num_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,10 @@ def compute_stats(self, samples, context=False):

def process(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self.
max_num, samples[Fields.stats])
return list(
map(
lambda stat: self.min_num <= stat[StatsKeys.num_words] <=
self.max_num, samples[Fields.stats]))
else:
# single sample for ray filter
if self.min_num <= samples[Fields.stats][
Expand Down
6 changes: 3 additions & 3 deletions data_juicer/ops/mapper/chinese_convert_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, mode: str = 's2t', *args, **kwargs):
def process(self, samples):
prepare_converter(self.mode)

samples[self.text_key] = map(
lambda text: OPENCC_CONVERTER.convert(text),
samples[self.text_key])
samples[self.text_key] = list(
map(lambda text: OPENCC_CONVERTER.convert(text),
samples[self.text_key]))
return samples
6 changes: 3 additions & 3 deletions data_juicer/ops/mapper/clean_copyright_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _process_single_sample(self, sample):
return sample

def process(self, samples):
samples[self.text_key] = map(
lambda text: self._process_single_sample(text),
samples[self.text_key])
samples[self.text_key] = list(
map(lambda text: self._process_single_sample(text),
samples[self.text_key]))
return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/clean_html_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ def _clean_html(raw_html):
parser = HTMLParser(raw_html)
return parser.text()

samples[self.text_key] = map(lambda text: _clean_html(text),
samples[self.text_key])
samples[self.text_key] = list(
map(lambda text: _clean_html(text), samples[self.text_key]))
return samples
8 changes: 5 additions & 3 deletions data_juicer/ops/mapper/fix_unicode_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __init__(self, normalization: str = None, *args, **kwargs):
'["NFC", "NFKC", "NFD", "NFKD"]')

def process(self, samples):
samples[self.text_key] = map(
lambda text: ftfy.fix_text(text, normalization=self.normalization),
samples[self.text_key])
samples[self.text_key] = list(
map(
lambda text: ftfy.fix_text(text,
normalization=self.normalization),
samples[self.text_key]))
return samples
9 changes: 5 additions & 4 deletions data_juicer/ops/mapper/punctuation_normalization_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def __init__(self, *args, **kwargs):
}

def process(self, samples):
samples[self.text_key] = map(
lambda text: ''.join(
[self.punctuation_unicode.get(c, c) for c in text]),
samples[self.text_key])
samples[self.text_key] = list(
map(
lambda text: ''.join(
[self.punctuation_unicode.get(c, c) for c in text]),
samples[self.text_key]))
return samples
10 changes: 6 additions & 4 deletions data_juicer/ops/mapper/remove_bibliography_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def __init__(self, *args, **kwargs):
self.pattern += r').*$'

def process(self, samples):
samples[self.text_key] = map(
lambda text: re.sub(
pattern=self.pattern, repl=r'', string=text, flags=re.DOTALL),
samples[self.text_key])
samples[self.text_key] = list(
map(
lambda text: re.sub(pattern=self.pattern,
repl=r'',
string=text,
flags=re.DOTALL), samples[self.text_key]))

return samples
10 changes: 6 additions & 4 deletions data_juicer/ops/mapper/remove_specific_chars_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def process(self, samples):
if self.pattern is None:
return samples

samples[self.text_key] = map(
lambda text: re.sub(
pattern=self.pattern, repl=r'', string=text, flags=re.DOTALL),
samples[self.text_key])
samples[self.text_key] = list(
map(
lambda text: re.sub(pattern=self.pattern,
repl=r'',
string=text,
flags=re.DOTALL), samples[self.text_key]))
return samples

0 comments on commit 6072c18

Please sign in to comment.