* restore to batched version and rename to xxx_batched
HYLcool committed Oct 15, 2024
1 parent b98f6d9 commit c2188d5
Showing 4 changed files with 132 additions and 97 deletions.
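Note: "batched" here means each op now receives a column-oriented dict (one list per field) instead of a single flat sample dict. A minimal sketch of that shape, with hypothetical literal keys standing in for data_juicer's Fields.stats / StatsKeys constants:

# Hypothetical columnar batch as consumed by the *_batched methods below.
# The literal key '__dj__stats__' stands in for Fields.stats; the real
# values live in data_juicer.utils.constant.
batch = {
    'text': ['first sample', 'a second, longer sample'],
    '__dj__stats__': [{}, {}],  # one stats dict per sample
}
# compute_stats_batched fills batch['__dj__stats__'][i] in place;
# process_batched yields one keep/drop decision per sample.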
36 changes: 23 additions & 13 deletions data_juicer/ops/filter/text_length_filter.py
@@ -33,17 +33,27 @@ def __init__(self,
         self.min_len = min_len
         self.max_len = max_len
 
-    def compute_stats(self, sample):
-        # check if it's computed already
-        if StatsKeys.text_len in sample[Fields.stats]:
-            return sample
-
-        sample[Fields.stats][StatsKeys.text_len] = len(sample[self.text_key])
-        return sample
-
-    def process(self, sample):
-        if self.min_len <= sample[Fields.stats][
-                StatsKeys.text_len] <= self.max_len:
-            return True
+    def compute_stats_batched(self, samples):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
+        for i, stat in enumerate(samples_stats):
+            # check if it's computed already
+            if StatsKeys.text_len in stat:
+                continue
+            else:
+                samples_stats[i][StatsKeys.text_len] = len(samples_list[i])
+
+        return samples
+
+    def process_batched(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            return map(
+                lambda stat: self.min_len <= stat[StatsKeys.text_len] <= self.
+                max_len, samples[Fields.stats])
         else:
-            return False
+            # single sample for ray filter
+            if self.min_len <= samples[Fields.stats][
+                    StatsKeys.text_len] <= self.max_len:
+                return True
+            else:
+                return False
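Note: a standalone sketch of how the renamed pair behaves on a toy batch, mirroring the logic above rather than importing the op (the bounds min_len=5 and max_len=20 are made up for illustration):

# Stats pass: one text_len per sample, as compute_stats_batched does.
texts = ['hi', 'hello there', 'x' * 50]
stats = [{'text_len': len(t)} for t in texts]
# Filter pass: in the batched branch, process_batched returns a lazy
# map of booleans; materialize it to inspect the keep mask.
keep = list(map(lambda s: 5 <= s['text_len'] <= 20, stats))
print(keep)  # [False, True, False]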
117 changes: 64 additions & 53 deletions data_juicer/ops/filter/word_repetition_filter.py
@@ -58,57 +58,68 @@ def __init__(self,
             self.model_key = prepare_model(model_type='sentencepiece',
                                            lang=lang)
 
-    def compute_stats(self, sample, context=False):
-        # check if it's computed already
-        if StatsKeys.word_rep_ratio in sample[Fields.stats]:
-            return sample
-
-        # try to get words from context
-        words_key = f'{InterVars.words}-{self.model_key}'
-        if context and words_key in sample[Fields.context]:
-            words = sample[Fields.context][words_key]
+    def compute_stats_batched(self, samples, context=False):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
+
+        for idx, stat in enumerate(samples_stats):
+            words_key = f'{InterVars.words}-{self.model_key}-{idx}'
+            # check if it's computed already
+            if StatsKeys.word_rep_ratio in stat:
+                continue
+            # try to get words from context
+            if context and words_key in samples[Fields.context]:
+                words = samples[Fields.context][words_key]
+            else:
+                tokenizer = get_model(self.model_key)
+                words = get_words_from_document(
+                    samples_list[idx],
+                    token_func=tokenizer.encode_as_pieces
+                    if tokenizer else None)
+                if context:
+                    samples[Fields.context][words_key] = words
+
+            # try to get refined words from context
+            refined_words_key = f'{InterVars.refined_words}-' \
+                                f'True-SPECIAL_CHARS-False-[2]-{idx}'
+            if context and refined_words_key in samples[Fields.context]:
+                words = samples[Fields.context][refined_words_key]
+            else:
+                words = words_refinement(words,
+                                         lower_case=True,
+                                         strip_chars=SPECIAL_CHARACTERS)
+                if context:
+                    samples[Fields.context][refined_words_key] = words
+            word_ngrams = [
+                ' '.join(words[i:i + self.n])
+                for i in range(len(words) - self.n + 1)
+            ]
+            freq_word_ngrams = {}
+            for word_ngram in word_ngrams:
+                freq_word_ngrams[word_ngram] = (
+                    freq_word_ngrams.get(word_ngram, 0) + 1)
+
+            if len(freq_word_ngrams) == 0:
+                samples_stats[idx][StatsKeys.word_rep_ratio] = 0.0
+                continue
+
+            freq_word_ngrams = list(freq_word_ngrams.values())
+            rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1]
+            samples_stats[idx][StatsKeys.word_rep_ratio] = (
+                sum(rep_more_than_one) /
+                sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0
+
+        return samples
+
+    def process_batched(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            return map(
+                lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio]
+                <= self.max_ratio, samples[Fields.stats])
         else:
-            tokenizer = get_model(self.model_key)
-            words = get_words_from_document(
-                sample[self.text_key],
-                token_func=tokenizer.encode_as_pieces if tokenizer else None)
-            if context:
-                sample[Fields.context][words_key] = words
-
-        # try to get refined words from context
-        refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \
-                            f'False-[2]-'
-        if context and refined_words_key in sample[Fields.context]:
-            words = sample[Fields.context][refined_words_key]
-        else:
-            words = words_refinement(words,
-                                     lower_case=True,
-                                     strip_chars=SPECIAL_CHARACTERS)
-            if context:
-                sample[Fields.context][refined_words_key] = words
-        word_ngrams = [
-            ' '.join(words[i:i + self.n])
-            for i in range(len(words) - self.n + 1)
-        ]
-        freq_word_ngrams = {}
-        for word_ngram in word_ngrams:
-            freq_word_ngrams[word_ngram] = (
-                freq_word_ngrams.get(word_ngram, 0) + 1)
-
-        if len(freq_word_ngrams) == 0:
-            sample[Fields.stats][StatsKeys.word_rep_ratio] = 0.0
-            return sample
-
-        freq_word_ngrams = list(freq_word_ngrams.values())
-        rep_more_than_one = [freq for freq in freq_word_ngrams if freq > 1]
-        sample[Fields.stats][StatsKeys.word_rep_ratio] = (
-            sum(rep_more_than_one) /
-            sum(freq_word_ngrams)) if sum(freq_word_ngrams) != 0 else 0.0
-        return sample
-
-    def process(self, sample):
-        if self.min_ratio <= sample[Fields.stats][StatsKeys.word_rep_ratio] \
-                <= self.max_ratio:
-            return True
-        else:
-            return False
+            # single sample for ray filter
+            if self.min_ratio <= samples[Fields.stats][
+                    StatsKeys.word_rep_ratio] <= self.max_ratio:
+                return True
+            else:
+                return False
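Note: the stat computed per sample above is the word-repetition ratio, i.e. the fraction of n-gram occurrences belonging to n-grams that appear more than once. A self-contained sketch of that computation, where Counter replaces the manual frequency dict and n plays the role of self.n:

from collections import Counter

def word_rep_ratio(words, n=10):
    # all n-grams of the word list, joined into hashable strings
    ngrams = [' '.join(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    freqs = Counter(ngrams).values()
    # occurrences of repeated n-grams over all n-gram occurrences
    return sum(f for f in freqs if f > 1) / sum(freqs)

print(word_rep_ratio(['a', 'b'] * 12, n=2))     # 1.0: every bigram repeats
print(word_rep_ratio(list('abcdefghij'), n=2))  # 0.0: all bigrams unique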
55 changes: 33 additions & 22 deletions data_juicer/ops/filter/words_num_filter.py
@@ -51,28 +51,39 @@ def __init__(self,
             self.model_key = prepare_model(model_type='sentencepiece',
                                            lang=lang)
 
-    def compute_stats(self, sample, context=False):
-        # check if it's computed already
-        if StatsKeys.num_words in sample[Fields.stats]:
-            return sample
+    def compute_stats_batched(self, samples, context=False):
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
 
-        words_key = f'{InterVars.words}-{self.model_key}'
-        if context and words_key in sample[Fields.context]:
-            words = sample[Fields.context][words_key]
-        else:
-            tokenizer = get_model(self.model_key)
-            words = get_words_from_document(
-                sample[self.text_key],
-                token_func=tokenizer.encode_as_pieces if tokenizer else None)
-            if context:
-                sample[Fields.context][words_key] = words
-        words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
-        sample[Fields.stats][StatsKeys.num_words] = len(words)
-        return sample
+        for idx, stat in enumerate(samples_stats):
+            words_key = f'{InterVars.words}-{self.model_key}-{idx}'
+            # check if it's computed already
+            if StatsKeys.num_words in stat:
+                continue
+            if context and words_key in samples[Fields.context]:
+                words = samples[Fields.context][words_key]
+            else:
+                tokenizer = get_model(self.model_key)
+                words = get_words_from_document(
+                    samples_list[idx],
+                    token_func=tokenizer.encode_as_pieces
+                    if tokenizer else None)
+                if context:
+                    samples[Fields.context][words_key] = words
+            words = words_refinement(words, strip_chars=SPECIAL_CHARACTERS)
+            samples_stats[idx][StatsKeys.num_words] = len(words)
+
+        return samples
 
-    def process(self, sample):
-        if self.min_num <= sample[Fields.stats][
-                StatsKeys.num_words] <= self.max_num:
-            return True
+    def process_batched(self, samples):
+        if isinstance(samples[Fields.stats], list):
+            return map(
+                lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self.
+                max_num, samples[Fields.stats])
         else:
-            return False
+            # single sample for ray filter
+            if self.min_num <= samples[Fields.stats][
+                    StatsKeys.num_words] <= self.max_num:
+                return True
+            else:
+                return False
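Note: across all three filters, the context cache key now carries a trailing -{idx} (e.g. f'{InterVars.words}-{self.model_key}-{idx}'), presumably so each sample in a batch gets its own cached tokenization. A hypothetical illustration of that keying scheme, with a plain dict and a whitespace tokenizer standing in for the real context store and sentencepiece model:

context = {}

def words_for(context, model_key, idx, text):
    key = f'words-{model_key}-{idx}'  # per-sample key, as in the diffs
    if key not in context:
        context[key] = text.split()   # stand-in for the real tokenizer
    return context[key]

print(words_for(context, 'spm', 0, 'one two'))  # tokenizes and caches
print(words_for(context, 'spm', 0, 'ignored'))  # cache hit: ['one', 'two']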
21 changes: 12 additions & 9 deletions data_juicer/ops/mapper/whitespace_normalization_mapper.py
@@ -27,12 +27,15 @@ def __init__(self, *args, **kwargs):
         """
         super().__init__(*args, **kwargs)
 
-    def process(self, sample):
-        text = sample[self.text_key].strip()
-
-        # replace all kinds of whitespaces with ' '
-        sample[self.text_key] = ''.join([
-            char if char not in VARIOUS_WHITESPACES else ' ' for char in text
-        ])
-
-        return sample
+    def process_batched(self, samples):
+        for idx, text in enumerate(samples[self.text_key]):
+            # remove whitespaces before and after the main content
+            text = text.strip()
+
+            # replace all kinds of whitespaces with ' '
+            samples[self.text_key][idx] = ''.join([
+                char if char not in VARIOUS_WHITESPACES else ' '
+                for char in text
+            ])
+
+        return samples
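Note: a standalone sketch of the per-text normalization the loop above applies. The small VARIOUS_WHITESPACES set here is illustrative only; the real constant in data_juicer covers many more Unicode code points:

VARIOUS_WHITESPACES = {'\t', '\n', '\r', '\u00a0', '\u2009', '\u3000'}

def normalize_whitespace(text: str) -> str:
    # strip the ends, then map every exotic whitespace to a plain space
    text = text.strip()
    return ''.join(' ' if ch in VARIOUS_WHITESPACES else ch for ch in text)

print(repr(normalize_whitespace('  hello\u3000world\n')))  # 'hello world'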
