Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization for batched processing #448

Merged
merged 14 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,33 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.process = catch_map_batches_exception(self.process)
self.process = catch_map_batches_exception(self.process_batched)
HYLcool marked this conversation as resolved.
Show resolved Hide resolved
else:
self.process = catch_map_single_exception(self.process)

def process(self, sample):
self.process = catch_map_single_exception(self.process_single)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
not_allowed_list = ['process']
for method_name in not_allowed_list:
if method_name in cls.__dict__:
raise TypeError(
f'Method {method_name} cannot be overridden by subclass '
f'{cls.__name__}. Please implement {method_name}_single '
f'or {method_name}_batched.')

def process_batched(self, samples, *args, **kwargs):
keys = samples.keys()
first_key = next(iter(keys))
num_samples = len(samples[first_key])
for i in range(num_samples):
this_sample = {key: samples[key][i] for key in keys}
res_sample = self.process_single(this_sample, *args, **kwargs)
for key in keys:
samples[key][i] = res_sample[key]

return samples

def process_single(self, sample):
"""
For sample level, sample --> sample

Expand Down Expand Up @@ -285,11 +307,41 @@ def __init__(self, *args, **kwargs):
# runtime wrappers
if self.is_batched_op():
self.compute_stats = catch_map_batches_exception(
self.compute_stats)
self.compute_stats_batched)
self.process = catch_map_batches_exception(self.process_batched)
else:
self.compute_stats = catch_map_single_exception(self.compute_stats)

def compute_stats(self, sample, context=False):
self.compute_stats = catch_map_single_exception(
self.compute_stats_single)
self.process = catch_map_single_exception(self.process_single)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
not_allowed_list = ['compute_stats', 'process']
for method_name in not_allowed_list:
if method_name in cls.__dict__:
raise TypeError(
f'Method {method_name} cannot be overridden by subclass '
f'{cls.__name__}. Please implement {method_name}_single '
f'or {method_name}_batched.')

def compute_stats_batched(self, samples, *args, **kwargs):
keys = samples.keys()
num_samples = len(samples[Fields.stats])
for i in range(num_samples):
this_sample = {key: samples[key][i] for key in keys}
res_sample = self.compute_stats_single(this_sample, *args,
**kwargs)
samples[Fields.stats][i] = res_sample[Fields.stats]
if 'context' in kwargs and kwargs['context']:
samples[Fields.context][i] = res_sample[Fields.context]

return samples

def process_batched(self, samples):
return map(lambda stat: self.process_single({Fields.stats: stat}),
samples[Fields.stats])

def compute_stats_single(self, sample, context=False):
"""
Compute stats for the sample which is used as a metric to decide
whether to filter this sample.
Expand All @@ -301,7 +353,7 @@ def compute_stats(self, sample, context=False):
"""
raise NotImplementedError

def process(self, sample):
def process_single(self, sample):
"""
For sample level, sample --> Boolean.

Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/deduplicator/ray_basic_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def calculate_hash(self, sample, context=False):
"""Calculate hash value for the sample."""
raise NotImplementedError

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# init redis client
r = redis.StrictRedis(host=self.redis_host, port=self.redis_port, db=0)
# compute hash
Expand All @@ -59,5 +59,5 @@ def compute_stats(self, sample, context=False):
sample[HashKeys.is_duplicate] = r.setnx(md5_value, 1)
return sample

def process(self, sample):
def process_single(self, sample):
return sample[HashKeys.is_duplicate]
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self,
pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped',
return_model=False)

def compute_stats(self, samples):
def compute_stats_batched(self, samples):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]

Expand Down Expand Up @@ -79,7 +79,7 @@ def compute_stats(self, samples):

return samples

def process(self, samples):
def process_batched(self, samples):
ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \
else StatsKeys.alnum_ratio
if isinstance(samples[Fields.stats], list):
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/audio_duration_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.audio_duration in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -74,7 +74,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
audio_durations = sample[Fields.stats][StatsKeys.audio_duration]
keep_bools = np.array([
self.min_duration <= duration <= self.max_duration
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/audio_nmf_snr_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.audio_nmf_snr in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
audio_snrs = sample[Fields.stats][StatsKeys.audio_nmf_snr]
keep_bools = np.array(
[self.min_snr <= snr <= self.max_snr for snr in audio_snrs])
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/audio_size_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.audio_sizes in sample[Fields.stats]:
return sample
Expand All @@ -58,7 +58,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
audio_sizes = sample[Fields.stats][StatsKeys.audio_sizes]
keep_bools = np.array([
self.min_size <= audio_size <= self.max_size
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/average_line_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self,
self.min_len = min_len
self.max_len = max_len

def compute_stats(self, samples, context=False):
def compute_stats_batched(self, samples, context=False):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
context_key = f'{InterVars.lines}'
Expand All @@ -58,7 +58,7 @@ def compute_stats(self, samples, context=False):
len(cur_text) / len(lines) if len(lines) != 0 else 0.0
return samples

def process(self, samples):
def process_batched(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_len <= stat[StatsKeys.avg_line_length] <=
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/character_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio

def compute_stats(self, samples):
def compute_stats_batched(self, samples):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]

Expand Down Expand Up @@ -78,7 +78,7 @@ def compute_stats(self, samples):

return samples

def process(self, samples):
def process_batched(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_ratio <= stat[StatsKeys.char_rep_ratio]
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/flagged_words_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self,
self.model_key = prepare_model(model_type='sentencepiece',
lang=lang)

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.flagged_words_ratio in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -120,6 +120,6 @@ def compute_stats(self, sample, context=False):
StatsKeys.flagged_words_ratio] = flagged_words_ratio
return sample

def process(self, sample):
def process_single(self, sample):
return sample[Fields.stats][
StatsKeys.flagged_words_ratio] <= self.max_ratio
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_aesthetics_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self,
self.need_normalized_by_ten = ('shunk031/aesthetics-predictor'
in hf_scorer_model)

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_aesthetics_scores in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -107,7 +107,7 @@ def compute_stats(self, sample, rank=None, context=False):
aesthetics_scores
return sample

def process(self, sample):
def process_single(self, sample):
aesthetics_scores = (
sample)[Fields.stats][StatsKeys.image_aesthetics_scores]
if len(aesthetics_scores) <= 0:
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_aspect_ratio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.aspect_ratios in sample[Fields.stats]:
return sample
Expand All @@ -66,7 +66,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios]
keep_bools = np.array([
self.min_ratio <= aspect_ratio <= self.max_ratio
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_face_count_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self,
self.model_key = prepare_model(model_type='opencv_classifier',
model_path=cv_classifier)

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.face_ratios in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -109,7 +109,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
face_counts = sample[Fields.stats][StatsKeys.face_counts]
if len(face_counts) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_face_ratio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self,
self.model_key = prepare_model(model_type='opencv_classifier',
model_path=cv_classifier)

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.face_ratios in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -112,7 +112,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
face_ratios = sample[Fields.stats][StatsKeys.face_ratios]
if len(face_ratios) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self,
pretrained_model_name_or_path=hf_nsfw_model,
trust_remote_code=trust_remote_code)

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_nsfw_score in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -84,7 +84,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
itm_scores = sample[Fields.stats][StatsKeys.image_nsfw_score]
if len(itm_scores) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_pair_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self,
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):

# check if it's computed already
if StatsKeys.image_pair_similarity in sample[Fields.stats]:
Expand Down Expand Up @@ -97,7 +97,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.image_pair_similarity]
if len(similarity) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_shape_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.image_width in sample[Fields.stats] \
and StatsKeys.image_height in sample[Fields.stats]:
Expand Down Expand Up @@ -76,7 +76,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
ws = sample[Fields.stats][StatsKeys.image_width]
hs = sample[Fields.stats][StatsKeys.image_height]
if len(ws) <= 0:
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_size_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.image_sizes in sample[Fields.stats]:
return sample
Expand All @@ -58,7 +58,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
image_sizes = sample[Fields.stats][StatsKeys.image_sizes]
keep_bools = np.array([
self.min_size <= image_size <= self.max_size
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_text_matching_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self,
self.horizontal_flip = horizontal_flip
self.vertical_flip = vertical_flip

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_text_matching_score in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -139,7 +139,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
itm_scores = sample[Fields.stats][StatsKeys.image_text_matching_score]
if len(itm_scores) <= 0:
return True
Expand Down
Loading
Loading