support batch_size>1 for some operators #406

Merged (24 commits) on Sep 23, 2024
Changes from 18 commits
18 changes: 16 additions & 2 deletions data_juicer/core/data.py
@@ -230,8 +230,9 @@ def map(self, *args, **kargs):
# Batched is always required for fault tolerance
if inspect.ismethod(called_func):
kargs['batched'] = True
kargs['batch_size'] = kargs.pop(
'batch_size', 1) if called_func.__self__.is_batched_op() else 1
kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr(
called_func.__self__, 'is_batched_op'
) and called_func.__self__.is_batched_op() else 1

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
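Note on the hunk above: `map` now honors a caller-supplied `batch_size` only when the bound method belongs to an op that both defines and affirms `is_batched_op()`; anything else is forced back to `batch_size=1`. A minimal sketch of the decision, with a hypothetical `resolve_batch_kwargs` helper standing in for the inline logic:

    import inspect

    def resolve_batch_kwargs(called_func, kargs):
        # Batched mode is always on for fault tolerance; batch_size is
        # only negotiable for ops that declare themselves batched.
        if inspect.ismethod(called_func):
            kargs['batched'] = True
            op = called_func.__self__
            batched = hasattr(op, 'is_batched_op') and op.is_batched_op()
            kargs['batch_size'] = kargs.pop('batch_size', 1) if batched else 1
        return kargs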
@@ -262,12 +263,25 @@ def filter(self, *args, **kargs):
args[0] = lambda x: nested_obj_factory(x)
else:
args[0] = wrap_func_with_nested_access(args[0])
called_func = args[0]
else:
if 'function' not in kargs or kargs['function'] is None:
kargs['function'] = lambda x: nested_obj_factory(x)
else:
kargs['function'] = wrap_func_with_nested_access(
kargs['function'])
called_func = kargs['function']

# For wrapped function, try to get its unwrapped (bound) method
while not inspect.ismethod(called_func) and hasattr(
called_func, '__wrapped__'):
called_func = called_func.__wrapped__

# Batched is always required for fault tolerance
if inspect.ismethod(
called_func) and called_func.__self__.is_batched_op():
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1)

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
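The `__wrapped__` walk relies on the convention set by `functools.wraps`: each decorator layer keeps a reference to the callable it wraps, so following the chain recovers the original bound method and, via `__self__`, the op instance. A self-contained sketch (toy `DemoOp`, not Data-Juicer code) of why the loop terminates at the bound method:

    import functools
    import inspect

    class DemoOp:
        def is_batched_op(self):
            return True

        def process(self, samples):
            return samples

    def wrap(fn):
        @functools.wraps(fn)  # sets wrapper.__wrapped__ = fn
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper

    called_func = wrap(wrap(DemoOp().process))  # two decorator layers
    while not inspect.ismethod(called_func) and hasattr(
            called_func, '__wrapped__'):
        called_func = called_func.__wrapped__
    assert inspect.ismethod(called_func)
    assert called_func.__self__.is_batched_op()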
6 changes: 4 additions & 2 deletions data_juicer/core/ray_data.py
@@ -108,14 +108,16 @@ def _run_single_op(self, op):
self.num_proc, op.use_cuda())
num_gpus = get_num_gpus(op, op_proc)
try:
batch_size = getattr(op, 'batch_size',
1) if op.is_batched_op() else 1
if isinstance(op, Mapper):
self.data = self.data.map_batches(op.process,
batch_size=1,
batch_size=batch_size,
batch_format='pyarrow',
num_gpus=num_gpus)
elif isinstance(op, Filter):
self.data = self.data.map_batches(op.compute_stats,
batch_size=1,
batch_size=batch_size,
batch_format='pyarrow',
num_gpus=num_gpus)
if op.stats_export_path is not None:
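With this change each Ray task sees `batch_size` rows as one PyArrow table rather than a single row, which amortizes per-task overhead for batched ops. A standalone sketch of the same call shape (toy dataset and column name, not Data-Juicer code):

    import pyarrow as pa
    import ray

    ds = ray.data.from_items([{'text': 'abc'}, {'text': 'defg'}])

    def add_len(batch: pa.Table) -> pa.Table:
        # batch_format='pyarrow' delivers each batch as a pyarrow.Table
        lens = pa.array([len(t) for t in batch['text'].to_pylist()])
        return batch.append_column('len', lens)

    out = ds.map_batches(add_len, batch_size=2, batch_format='pyarrow')
    print(out.take_all())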
5 changes: 5 additions & 0 deletions data_juicer/ops/base_op.py
@@ -133,6 +133,7 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get('image_key', 'images')
self.audio_key = kwargs.get('audio_key', 'audios')
self.video_key = kwargs.get('video_key', 'videos')
self.batch_size = kwargs.get('batch_size', 1)

# whether the model can be accelerated using cuda
_accelerator = kwargs.get('accelerator', None)
@@ -239,6 +240,7 @@ def run(self, dataset, *, exporter=None, tracer=None):
self.process,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
batch_size=self.batch_size,
desc=self._name + '_process',
)
if tracer:
@@ -302,15 +304,18 @@ def run(self, dataset, *, exporter=None, tracer=None):
'initial_value': {}
},
num_proc=self.runtime_np(),
batch_size=self.batch_size,
desc='Adding new column for stats')
dataset = dataset.map(self.compute_stats,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
batch_size=self.batch_size,
desc=self._name + '_compute_stats')
if self.stats_export_path is not None:
exporter.export_compute_stats(dataset, self.stats_export_path)
new_dataset = dataset.filter(self.process,
num_proc=self.runtime_np(),
batch_size=self.batch_size,
desc=self._name + '_process')
if tracer:
tracer.trace_filter(self._name, dataset, new_dataset)
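Across both hunks, the op-level `self.batch_size` (read from kwargs in `__init__`, defaulting to 1) is threaded into every `dataset.map`/`dataset.filter` call, so Hugging Face `datasets` feeds a batched op that many samples per invocation as a dict of columns. A toy sketch of the effect, independent of Data-Juicer:

    from datasets import Dataset

    ds = Dataset.from_dict({'text': ['a', 'bb', 'ccc', 'dddd']})

    def process(batch):
        # with batched=True, `batch` is {'text': [up to batch_size strings]}
        return {'text': [t.upper() for t in batch['text']]}

    ds = ds.map(process, batched=True, batch_size=2)
    print(ds['text'])  # ['A', 'BB', 'CCC', 'DDDD']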
77 changes: 46 additions & 31 deletions data_juicer/ops/filter/alphanumeric_filter.py
@@ -20,6 +20,8 @@ class AlphanumericFilter(Filter):
"""Filter to keep samples with alphabet/numeric ratio within a specific
range."""

_batched_op = True

def __init__(self,
tokenization: bool = False,
min_ratio: float = 0.25,
@@ -54,36 +56,49 @@ def __init__(self,
pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped',
return_model=False)

def compute_stats(self, sample):
if self.tokenization:
if StatsKeys.alpha_token_ratio in sample[Fields.stats]:
return sample
alpha_count = sum(
map(lambda char: 1
if char.isalpha() else 0, sample[self.text_key]))
tokenizer = get_model(self.model_key)
token_count = len(
get_words_from_document(
sample[self.text_key],
token_func=tokenizer.tokenize if tokenizer else None))
sample[Fields.stats][StatsKeys.alpha_token_ratio] = (
alpha_count / token_count) if token_count != 0 else 0.0
else:
if StatsKeys.alnum_ratio in sample[Fields.stats]:
return sample
alnum_count = sum(
map(lambda char: 1
if char.isalnum() else 0, sample[self.text_key]))
sample[Fields.stats][StatsKeys.alnum_ratio] = (
alnum_count / len(sample[self.text_key])) if len(
sample[self.text_key]) != 0 else 0.0
return sample
def compute_stats(self, samples):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]

for idx, stat in enumerate(samples_stats):
cur_text = samples_list[idx]
if self.tokenization:
if StatsKeys.alpha_token_ratio in stat:
continue
alpha_count = sum(
map(lambda char: 1 if char.isalpha() else 0, cur_text))
tokenizer = get_model(self.model_key)
token_count = len(
get_words_from_document(
cur_text,
token_func=tokenizer.tokenize if tokenizer else None))
samples_stats[idx][StatsKeys.alpha_token_ratio] = (
alpha_count / token_count) if token_count != 0 else 0.0
else:
if StatsKeys.alnum_ratio in stat:
continue
alnum_count = sum(
map(lambda char: 1 if char.isalnum() else 0, cur_text))
samples_stats[idx][StatsKeys.alnum_ratio] = (
alnum_count / len(cur_text)) if len(cur_text) != 0 else 0.0

return samples

def process(self, sample):
ratio = sample[Fields.stats][
StatsKeys.alpha_token_ratio] if self.tokenization else sample[
Fields.stats][StatsKeys.alnum_ratio]
if self.min_ratio <= ratio <= self.max_ratio:
return True
def process(self, samples):
ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \
else StatsKeys.alnum_ratio
if isinstance(samples[Fields.stats], list):
bool_results = []
for stat in samples[Fields.stats]:
if self.min_ratio <= stat[ratio_key] <= self.max_ratio:
bool_results.append(True)
else:
bool_results.append(False)
return bool_results
else:
return False
# single sample for ray filter
if self.min_ratio <= samples[
Fields.stats][ratio_key] <= self.max_ratio:
return True
else:
return False
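The batched signature changes the unit of work from one sample dict to a column-oriented dict of lists: `samples[self.text_key]` is a list of texts, `samples[Fields.stats]` a parallel list of stat dicts, and `process` returns one boolean per sample, except on the Ray path, which still passes a single sample and expects a single boolean. A stripped-down sketch of the dual-path pattern (plain string keys stand in for Data-Juicer's `Fields`/`StatsKeys` constants):

    TEXT_KEY, STATS_KEY = 'text', 'stats'  # stand-ins for the real constants

    def compute_stats(samples):
        for idx, stat in enumerate(samples[STATS_KEY]):
            if 'alnum_ratio' in stat:  # already computed by an earlier run
                continue
            text = samples[TEXT_KEY][idx]
            alnum = sum(1 for c in text if c.isalnum())
            stat['alnum_ratio'] = alnum / len(text) if text else 0.0
        return samples

    def process(samples, min_ratio=0.25, max_ratio=0.5):
        stats = samples[STATS_KEY]
        if isinstance(stats, list):  # batched (HF datasets) path
            return [min_ratio <= s['alnum_ratio'] <= max_ratio for s in stats]
        # single-sample (Ray filter) path
        return min_ratio <= stats['alnum_ratio'] <= max_ratio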
65 changes: 42 additions & 23 deletions data_juicer/ops/filter/average_line_length_filter.py
@@ -7,13 +7,17 @@
from ..base_op import OPERATORS, Filter
from ..op_fusion import INTER_LINES

OP_NAME = 'average_line_length_filter'

@OPERATORS.register_module('average_line_length_filter')
@INTER_LINES.register_module('average_line_length_filter')

@OPERATORS.register_module(OP_NAME)
@INTER_LINES.register_module(OP_NAME)
class AverageLineLengthFilter(Filter):
"""Filter to keep samples with average line length within a specific
range."""

_batched_op = True

def __init__(self,
min_len: PositiveInt = 10,
max_len: PositiveInt = sys.maxsize,
@@ -35,26 +39,41 @@ def __init__(self,
self.min_len = min_len
self.max_len = max_len

def compute_stats(self, sample, context=False):
# check if it's computed already
if StatsKeys.avg_line_length in sample[Fields.stats]:
return sample

def compute_stats(self, samples, context=False):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
context_key = f'{InterVars.lines}'
if context and context_key in sample[Fields.context]:
lines = sample[Fields.context][context_key]
else:
lines = sample[self.text_key].splitlines()
if context:
sample[Fields.context][context_key] = lines
sample[Fields.stats][StatsKeys.avg_line_length] = \
len(sample[self.text_key]) / len(lines) \
if len(lines) != 0 else 0.0
return sample

def process(self, sample):
if self.min_len <= sample[Fields.stats][
StatsKeys.avg_line_length] <= self.max_len:
return True

for idx, stat in enumerate(samples_stats):
# check if it's computed already
if StatsKeys.avg_line_length in stat:
continue

cur_text = samples_list[idx]
if context and context_key in samples[Fields.context][idx]:
lines = samples[Fields.context][idx][context_key]
else:
lines = cur_text.splitlines()
if context:
samples[Fields.context][idx][context_key] = lines
samples_stats[idx][StatsKeys.avg_line_length] = \
len(cur_text) / len(lines) if len(lines) != 0 else 0.0
return samples

def process(self, samples):
if isinstance(samples[Fields.stats], list):
bool_results = []
for stat in samples[Fields.stats]:
if self.min_len <= stat[
StatsKeys.avg_line_length] <= self.max_len:
bool_results.append(True)
else:
bool_results.append(False)
return bool_results
else:
return False
# single sample for ray filter
if self.min_len <= samples[Fields.stats][
StatsKeys.avg_line_length] <= self.max_len:
return True
else:
return False
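One detail worth noting in the hunk above: with `context=True`, the per-sample line split is cached under `samples[Fields.context][idx]`, so fused ops that also operate on lines can reuse it instead of calling `splitlines()` again. A toy sketch of that caching (plain keys stand in for the real constants):

    def lines_for(samples, idx, context, context_key='lines'):
        if context and context_key in samples['context'][idx]:
            return samples['context'][idx][context_key]  # reuse cached split
        lines = samples['text'][idx].splitlines()
        if context:
            samples['context'][idx][context_key] = lines  # cache for fused ops
        return lines

    samples = {'text': ['a\nbb\nccc'], 'context': [{}]}
    lines = lines_for(samples, 0, context=True)
    avg_len = len(samples['text'][0]) / len(lines)  # 8 chars / 3 lines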
83 changes: 51 additions & 32 deletions data_juicer/ops/filter/character_repetition_filter.py
@@ -15,6 +15,8 @@ class CharacterRepetitionFilter(Filter):
"""Filter to keep samples with char-level n-gram repetition ratio within a
specific range."""

_batched_op = True

def __init__(self,
rep_len: PositiveInt = 10,
min_ratio: ClosedUnitInterval = 0.0,
@@ -39,40 +41,57 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio

def compute_stats(self, sample):
# check if it's computed already
if StatsKeys.char_rep_ratio in sample[Fields.stats]:
return sample
def compute_stats(self, samples):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]

for idx, stat in enumerate(samples_stats):
# check if it's computed already
if StatsKeys.char_rep_ratio in stat:
continue

cur_text = samples_list[idx]
char_ngrams = [
cur_text[i:i + self.n]
for i in range(len(cur_text) - self.n + 1)
]
freq_char_ngrams = {}
for char_ngram in char_ngrams:
freq_char_ngrams[char_ngram] = (
freq_char_ngrams.get(char_ngram, 0) + 1)

char_ngrams = [
sample[self.text_key][i:i + self.n]
for i in range(len(sample[self.text_key]) - self.n + 1)
]
freq_char_ngrams = {}
for char_ngram in char_ngrams:
freq_char_ngrams[char_ngram] = (
freq_char_ngrams.get(char_ngram, 0) + 1)
if len(freq_char_ngrams) == 0:
samples_stats[idx][StatsKeys.char_rep_ratio] = 0.0
continue

if len(freq_char_ngrams) == 0:
sample[Fields.stats][StatsKeys.char_rep_ratio] = 0.0
return sample
freq_char_ngrams = sorted(list(freq_char_ngrams.values()),
reverse=True)
num_no_rep_char_ngrams = len(
[el for el in freq_char_ngrams if el == 1])
num_rep_char_ngrams = min(
int(np.sqrt(len(freq_char_ngrams))),
len(freq_char_ngrams) - num_no_rep_char_ngrams,
)
samples_stats[idx][StatsKeys.char_rep_ratio] = (
sum(freq_char_ngrams[:num_rep_char_ngrams]) /
sum(freq_char_ngrams)) if sum(freq_char_ngrams) != 0 else 0.0

freq_char_ngrams = sorted(list(freq_char_ngrams.values()),
reverse=True)
num_no_rep_char_ngrams = len(
[el for el in freq_char_ngrams if el == 1])
num_rep_char_ngrams = min(
int(np.sqrt(len(freq_char_ngrams))),
len(freq_char_ngrams) - num_no_rep_char_ngrams,
)
sample[Fields.stats][StatsKeys.char_rep_ratio] = (sum(
freq_char_ngrams[:num_rep_char_ngrams]) / sum(freq_char_ngrams)) \
if sum(freq_char_ngrams) != 0 else 0.0
return sample
return samples

def process(self, sample):
if self.min_ratio <= sample[Fields.stats][StatsKeys.char_rep_ratio] \
<= self.max_ratio:
return True
def process(self, samples):
if isinstance(samples[Fields.stats], list):
bool_results = []
for stat in samples[Fields.stats]:
if self.min_ratio <= stat[
StatsKeys.char_rep_ratio] <= self.max_ratio:
bool_results.append(True)
else:
bool_results.append(False)
return bool_results
else:
return False
# single sample for ray filter
if self.min_ratio <= samples[Fields.stats][
StatsKeys.char_rep_ratio] <= self.max_ratio:
return True
else:
return False
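For reference, the statistic itself: with k distinct character n-grams, the ratio is the share of all n-gram occurrences claimed by the top min(floor(sqrt(k)), k - #unique) most frequent ones, so fully unique text scores 0. A compact sketch with a worked example:

    import numpy as np

    def char_rep_ratio(text, n=3):
        ngrams = [text[i:i + n] for i in range(len(text) - n + 1)]
        freqs = {}
        for g in ngrams:
            freqs[g] = freqs.get(g, 0) + 1
        if not freqs:
            return 0.0
        counts = sorted(freqs.values(), reverse=True)
        num_unique = sum(1 for c in counts if c == 1)
        num_rep = min(int(np.sqrt(len(counts))), len(counts) - num_unique)
        total = sum(counts)
        return sum(counts[:num_rep]) / total if total else 0.0

    # 'abc' appears 4 times among 10 trigrams; sqrt of 3 distinct caps us at 1
    print(char_rep_ratio('abcabcabcabc'))  # 0.4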