Commit
unify batch index id
Cathy0908 committed Sep 3, 2024
1 parent 947879f commit bc7ce9d
Showing 18 changed files with 74 additions and 73 deletions.
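All 18 files apply the same two-part change: the per-sample loop index is unified from i to idx, and repeated samples_list[i] lookups are hoisted into a cur_text local. For orientation, a minimal sketch of the batched-sample layout these ops assume (simplified field names; 'stats' stands in for samples[Fields.stats]):

    # Minimal sketch of the batched-sample convention (assumed layout):
    # `samples` maps each field to a parallel list, one entry per sample.
    samples = {
        'text': ['Hello, world!', 'foo bar 123'],
        'stats': [{}, {}],  # stands in for samples[Fields.stats]
    }

    # The unified loop shape after this commit: one `idx` over the batch,
    # with the current text cached locally instead of re-indexing the list.
    samples_list = samples['text']
    samples_stats = samples['stats']
    for idx, stat in enumerate(samples_stats):
        cur_text = samples_list[idx]
        stat['text_len'] = len(cur_text)  # hypothetical stat, for illustration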
18 changes: 8 additions & 10 deletions data_juicer/ops/filter/alphanumeric_filter.py
@@ -60,29 +60,27 @@ def compute_stats(self, samples):
         samples_list = samples[self.text_key]
         samples_stats = samples[Fields.stats]
 
-        for i, stat in enumerate(samples_stats):
+        for idx, stat in enumerate(samples_stats):
+            cur_text = samples_list[idx]
             if self.tokenization:
                 if StatsKeys.alpha_token_ratio in stat:
                     continue
                 alpha_count = sum(
-                    map(lambda char: 1
-                        if char.isalpha() else 0, samples_list[i]))
+                    map(lambda char: 1 if char.isalpha() else 0, cur_text))
                 tokenizer = get_model(self.model_key)
                 token_count = len(
                     get_words_from_document(
-                        samples_list[i],
+                        cur_text,
                         token_func=tokenizer.tokenize if tokenizer else None))
-                samples_stats[i][StatsKeys.alpha_token_ratio] = (
+                samples_stats[idx][StatsKeys.alpha_token_ratio] = (
                     alpha_count / token_count) if token_count != 0 else 0.0
             else:
                 if StatsKeys.alnum_ratio in stat:
                     continue
                 alnum_count = sum(
-                    map(lambda char: 1
-                        if char.isalnum() else 0, samples_list[i]))
-                samples_stats[i][StatsKeys.alnum_ratio] = (
-                    alnum_count /
-                    len(samples_list[i])) if len(samples_list[i]) != 0 else 0.0
+                    map(lambda char: 1 if char.isalnum() else 0, cur_text))
+                samples_stats[idx][StatsKeys.alnum_ratio] = (
+                    alnum_count / len(cur_text)) if len(cur_text) != 0 else 0.0
 
         return samples
15 changes: 8 additions & 7 deletions data_juicer/ops/filter/average_line_length_filter.py
@@ -44,19 +44,20 @@ def compute_stats(self, samples, context=False):
         samples_stats = samples[Fields.stats]
         context_key = f'{InterVars.lines}'
 
-        for i, stat in enumerate(samples_stats):
+        for idx, stat in enumerate(samples_stats):
             # check if it's computed already
             if StatsKeys.avg_line_length in stat:
                 continue
 
-            if context and context_key in samples[Fields.context][i]:
-                lines = samples[Fields.context][i][context_key]
+            cur_text = samples_list[idx]
+            if context and context_key in samples[Fields.context][idx]:
+                lines = samples[Fields.context][idx][context_key]
             else:
-                lines = samples_list[i].splitlines()
+                lines = cur_text.splitlines()
                 if context:
-                    samples[Fields.context][i][context_key] = lines
-            samples_stats[i][StatsKeys.avg_line_length] = \
-                len(samples_list[i]) / len(lines) if len(lines) != 0 else 0.0
+                    samples[Fields.context][idx][context_key] = lines
+            samples_stats[idx][StatsKeys.avg_line_length] = \
+                len(cur_text) / len(lines) if len(lines) != 0 else 0.0
         return samples
 
     def process(self, samples):
5 changes: 3 additions & 2 deletions data_juicer/ops/filter/character_repetition_filter.py
@@ -50,9 +50,10 @@ def compute_stats(self, samples):
             if StatsKeys.char_rep_ratio in stat:
                 continue
 
+            cur_text = samples_list[idx]
             char_ngrams = [
-                samples_list[idx][i:i + self.n]
-                for i in range(len(samples_list[idx]) - self.n + 1)
+                cur_text[i:i + self.n]
+                for i in range(len(cur_text) - self.n + 1)
             ]
             freq_char_ngrams = {}
             for char_ngram in char_ngrams:
12 changes: 6 additions & 6 deletions data_juicer/ops/filter/maximum_line_length_filter.py
@@ -44,19 +44,19 @@ def compute_stats(self, samples, context=False):
         samples_stats = samples[Fields.stats]
         context_key = f'{InterVars.lines}'
 
-        for i, stat in enumerate(samples_stats):
+        for idx, stat in enumerate(samples_stats):
             # check if it's computed already
             if StatsKeys.max_line_length in stat:
                 continue
 
-            if context and context_key in samples[Fields.context][i]:
-                lines = samples[Fields.context][i][context_key]
+            if context and context_key in samples[Fields.context][idx]:
+                lines = samples[Fields.context][idx][context_key]
             else:
-                lines = samples_list[i].splitlines()
+                lines = samples_list[idx].splitlines()
                 if context:
-                    samples[Fields.context][i][context_key] = lines
+                    samples[Fields.context][idx][context_key] = lines
             line_lengths = list(map(len, lines))
-            samples_stats[i][StatsKeys.max_line_length] = max(
+            samples_stats[idx][StatsKeys.max_line_length] = max(
                 line_lengths) if line_lengths else 0
 
         return samples
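Both line-length filters above share the same memoization pattern: when context=True, the split lines are cached in samples[Fields.context] so later ops can reuse them. A simplified sketch of the idea (assumed structure; 'lines' and 'context' stand in for the InterVars.lines key and Fields.context):

    # Simplified sketch of the per-sample context cache (assumed structure).
    def split_lines_cached(samples, idx, text, context=True, context_key='lines'):
        ctx = samples['context'][idx]  # stands in for samples[Fields.context][idx]
        if context and context_key in ctx:
            return ctx[context_key]    # reuse lines computed by an earlier op
        lines = text.splitlines()
        if context:
            ctx[context_key] = lines   # cache for the next op in the pipeline
        return lines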
12 changes: 6 additions & 6 deletions data_juicer/ops/filter/perplexity_filter.py
@@ -53,21 +53,21 @@ def compute_stats(self, samples, context=False):
         samples_stats = samples[Fields.stats]
         words_key = f'{InterVars.words}-{self.sp_model_key}'
 
-        for i, stat in enumerate(samples_stats):
+        for idx, stat in enumerate(samples_stats):
             # check if it's computed already
             if StatsKeys.perplexity in stat:
                 continue
             # tokenization
-            if context and words_key in samples[Fields.context][i]:
-                words = samples[Fields.context][i][words_key]
+            if context and words_key in samples[Fields.context][idx]:
+                words = samples[Fields.context][idx][words_key]
             else:
                 tokenizer = get_model(self.sp_model_key)
                 words = get_words_from_document(
-                    samples_list[i],
+                    samples_list[idx],
                     token_func=tokenizer.encode_as_pieces
                     if tokenizer else None)
                 if context:
-                    samples[Fields.context][i][words_key] = words
+                    samples[Fields.context][idx][words_key] = words
             text = ' '.join(words)
             # compute perplexity
             logits, length = 0, 0
@@ -76,7 +76,7 @@ def compute_stats(self, samples, context=False):
                 logits += kenlm_model.score(line)
                 length += (len(line.split()) + 1)
             ppl = (10.0**(-logits / length)) if length != 0 else 0.0
-            samples_stats[i][StatsKeys.perplexity] = round(ppl, 1)
+            samples_stats[idx][StatsKeys.perplexity] = round(ppl, 1)
 
         return samples
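For reference, the perplexity computed above is the standard formula: kenlm's score returns a log10 probability per line, so ppl = 10 ** (-logits / length) is 10 raised to the negative average log-probability per token, with one extra count per line for the end-of-sentence token. A worked numeric sketch (the scores are hypothetical):

    # Worked sketch of the formula above (assumes kenlm-style log10 scores).
    line_scores = [-12.0, -8.0]  # hypothetical log10 probabilities of two lines
    word_counts = [3, 2]         # words per line; +1 each for </s>, as in the op

    logits = sum(line_scores)                 # -20.0
    length = sum(n + 1 for n in word_counts)  # 7
    ppl = (10.0 ** (-logits / length)) if length != 0 else 0.0
    print(round(ppl, 1))                      # 719.7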
9 changes: 5 additions & 4 deletions data_juicer/ops/filter/special_characters_filter.py
@@ -42,14 +42,15 @@ def compute_stats(self, samples):
         samples_list = samples[self.text_key]
         samples_stats = samples[Fields.stats]
 
-        for i, stat in enumerate(samples_stats):
+        for idx, stat in enumerate(samples_stats):
             # check if it's computed already
             if StatsKeys.special_char_ratio in stat:
                 continue
+            cur_text = samples_list[idx]
             # get ratio of special characters
-            samples_stats[i][StatsKeys.special_char_ratio] = (
-                len([c for c in samples_list[i] if c in SPECIAL_CHARACTERS]) /
-                len(samples_list[i])) if len(samples_list[i]) != 0 else 0.0
+            samples_stats[idx][StatsKeys.special_char_ratio] = (
+                len([c for c in cur_text if c in SPECIAL_CHARACTERS]) /
+                len(cur_text)) if len(cur_text) != 0 else 0.0
 
         return samples
10 changes: 5 additions & 5 deletions data_juicer/ops/mapper/clean_email_mapper.py
@@ -31,12 +31,12 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
         self.repl = repl
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if not re.search(self.pattern, text, flags=re.DOTALL):
                 continue
-            samples[self.text_key][i] = re.sub(pattern=self.pattern,
-                                               repl=self.repl,
-                                               string=text,
-                                               flags=re.DOTALL)
+            samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+                                                 repl=self.repl,
+                                                 string=text,
+                                                 flags=re.DOTALL)
 
         return samples
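As a usage note, these regex mappers rewrite the batched text list in place and return the same samples dict. A sketch, assuming the default text_key of 'text' and the op's default email pattern:

    # Usage sketch (assumptions: default text_key 'text', default email pattern).
    from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper

    op = CleanEmailMapper(repl='<EMAIL>')
    samples = {'text': ['contact me at a.b@example.com', 'no email here']}
    samples = op.process(samples)
    # samples['text'][0] -> 'contact me at <EMAIL>'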
10 changes: 5 additions & 5 deletions data_juicer/ops/mapper/clean_ip_mapper.py
@@ -35,11 +35,11 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
         self.repl = repl
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if not re.search(self.pattern, text, flags=re.DOTALL):
                 continue
-            samples[self.text_key][i] = re.sub(pattern=self.pattern,
-                                               repl=self.repl,
-                                               string=text,
-                                               flags=re.DOTALL)
+            samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+                                                 repl=self.repl,
+                                                 string=text,
+                                                 flags=re.DOTALL)
         return samples
10 changes: 5 additions & 5 deletions data_juicer/ops/mapper/clean_links_mapper.py
@@ -41,12 +41,12 @@ def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
         self.repl = repl
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if not re.search(self.pattern, text, flags=re.DOTALL):
                 continue
 
-            samples[self.text_key][i] = re.sub(pattern=self.pattern,
-                                               repl=self.repl,
-                                               string=text,
-                                               flags=re.DOTALL)
+            samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+                                                 repl=self.repl,
+                                                 string=text,
+                                                 flags=re.DOTALL)
         return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/expand_macro_mapper.py
@@ -58,7 +58,7 @@ def _build_non_arg_macros_dict(self, file_content):
         return macros
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             non_arg_macros = self._build_non_arg_macros_dict(text)
 
             # TODO: macros that take arguments are not supported yet
@@ -80,6 +80,6 @@ def process(self, samples):
             for macro_name, macro_value in arg_macros.items():
                 pass
 
-            samples[self.text_key][i] = text
+            samples[self.text_key][idx] = text
 
         return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/remove_comments_mapper.py
@@ -42,7 +42,7 @@ def __init__(self,
     def process(self, samples):
         # TODO: remove different comments by sample type
 
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if self.inline:
                 # remove all in comments within a line
                 text = re.sub(pattern=r'[^\\]%.+$',
@@ -56,6 +56,6 @@ def process(self, samples):
                               string=text,
                               flags=re.MULTILINE)
 
-            samples[self.text_key][i] = text
+            samples[self.text_key][idx] = text
 
         return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/remove_header_mapper.py
@@ -37,7 +37,7 @@ def __init__(self, drop_no_head: bool = True, *args, **kwargs):
         self.drop_no_head = drop_no_head
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if not re.search(self.pattern, text, flags=re.DOTALL):
                 if self.drop_no_head:
                     text = ''
@@ -47,6 +47,6 @@ def process(self, samples):
                           string=text,
                           flags=re.DOTALL)
 
-            samples[self.text_key][i] = text
+            samples[self.text_key][idx] = text
 
         return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/remove_long_words_mapper.py
@@ -46,12 +46,12 @@ def should_keep_long_word(self, word):
         return False
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             sentences = split_on_newline_tab_whitespace(text)
             sentences = [[[
                 word for word in subsentence
                 if self.should_keep_long_word(word)
             ] for subsentence in sentence] for sentence in sentences]
-            samples[self.text_key][i] = merge_on_whitespace_tab_newline(
+            samples[self.text_key][idx] = merge_on_whitespace_tab_newline(
                 sentences)
         return samples
10 changes: 5 additions & 5 deletions data_juicer/ops/mapper/remove_non_chinese_character_mapper.py
@@ -36,12 +36,12 @@ def __init__(self,
         self.pattern += u']'
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if not re.search(self.pattern, text, flags=re.DOTALL):
                 continue
 
-            samples[self.text_key][i] = re.sub(pattern=self.pattern,
-                                               repl=r'',
-                                               string=text,
-                                               flags=re.DOTALL)
+            samples[self.text_key][idx] = re.sub(pattern=self.pattern,
+                                                 repl=r'',
+                                                 string=text,
+                                                 flags=re.DOTALL)
         return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/remove_repeat_sentences_mapper.py
@@ -46,7 +46,7 @@ def __init__(self,
         ) if ignore_special_character else None
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             lines = [e for e in text.split('\n')]
             new_lines = []
             hash_set = set([])
@@ -68,6 +68,6 @@ def process(self, samples):
                 hash_set.add(copy)
                 new_lines.append(new_sent)
 
-            samples[self.text_key][i] = '\n'.join(new_lines)
+            samples[self.text_key][idx] = '\n'.join(new_lines)
 
         return samples
8 changes: 4 additions & 4 deletions data_juicer/ops/mapper/remove_table_text_mapper.py
@@ -37,11 +37,11 @@ def __init__(self,
         self.pattern = r'(?<=\n)((\S+?)([ |\t](\S+?)){%d}\n+){2,}'
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
-            for i in range(self.min_col - 1, self.max_col):
-                pattern = re.compile(self.pattern % i)
+        for idx, text in enumerate(samples[self.text_key]):
+            for idx in range(self.min_col - 1, self.max_col):
+                pattern = re.compile(self.pattern % idx)
                 text = pattern.sub('', text)
 
-            samples[self.text_key][i] = text
+            samples[self.text_key][idx] = text
 
         return samples
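One caveat carried over unchanged by this rename: the inner column loop reuses the sample index (i before, idx after), shadowing the enumeration variable, so the final write targets position self.max_col - 1 rather than the current sample. A sketch of the fix inside process(self, samples), with a distinct inner name (hypothetical, not part of this commit):

    # Hypothetical fix (not in this commit): keep idx for the sample position
    # and use a separate variable, here n_col, for the column count.
    for idx, text in enumerate(samples[self.text_key]):
        for n_col in range(self.min_col - 1, self.max_col):
            pattern = re.compile(self.pattern % n_col)
            text = pattern.sub('', text)

        samples[self.text_key][idx] = text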
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py
@@ -51,7 +51,7 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings):
         return should_keep
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             if self.tokenization:
                 tokenizer = get_model(self.model_key)
                 sentences = get_words_from_document(
@@ -74,6 +74,6 @@ def process(self, samples):
                 ] for subsentence in sentence] for sentence in sentences]
                 text = merge_on_whitespace_tab_newline(sentences)
 
-            samples[self.text_key][i] = text
+            samples[self.text_key][idx] = text
 
         return samples
4 changes: 2 additions & 2 deletions data_juicer/ops/mapper/whitespace_normalization_mapper.py
@@ -28,12 +28,12 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def process(self, samples):
-        for i, text in enumerate(samples[self.text_key]):
+        for idx, text in enumerate(samples[self.text_key]):
             # remove whitespaces before and after the main content
             text = text.strip()
 
             # replace all kinds of whitespaces with ' '
-            samples[self.text_key][i] = ''.join([
+            samples[self.text_key][idx] = ''.join([
                 char if char not in VARIOUS_WHITESPACES else ' '
                 for char in text
             ])
