Skip to content

Commit

Permalink
* restore to batched version and rename to xxx_batched
Browse files (browse the repository at this point in the history)
  • Loading branch information
HYLcool committed Oct 17, 2024
1 parent 281c68d commit 5bf3173
Show file tree
Hide file tree
Showing 86 changed files with 146 additions and 146 deletions.
44 changes: 22 additions & 22 deletions data_juicer/ops/base_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,23 +236,23 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.process_branch = catch_map_batches_exception(
self.process_batched)
self.process = catch_map_batches_exception(self.process_batched)
else:
self.process_branch = catch_map_single_exception(self.process)
self.process = catch_map_single_exception(self.process_single)

def process_batched(self, samples, *args, **kwargs):
keys = samples.keys()
first_key = list(keys)[0]
for i in range(len(samples[first_key])):
first_key = next(iter(keys))
num_samples = len(samples[first_key])
for i in range(num_samples):
this_sample = {key: samples[key][i] for key in keys}
res_sample = self.process(this_sample, *args, **kwargs)
res_sample = self.process_single(this_sample, *args, **kwargs)
for key in keys:
samples[key][i] = res_sample[key]

return samples

def process(self, sample):
def process_single(self, sample):
"""
For sample level, sample --> sample
Expand All @@ -264,7 +264,7 @@ def process(self, sample):
def run(self, dataset, *, exporter=None, tracer=None):
dataset = super(Mapper, self).run(dataset)
new_dataset = dataset.map(
self.process_branch,
self.process,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
batch_size=self.batch_size,
Expand Down Expand Up @@ -296,32 +296,32 @@ def __init__(self, *args, **kwargs):

# runtime wrappers
if self.is_batched_op():
self.compute_stats_branch = catch_map_batches_exception(
self.compute_stats = catch_map_batches_exception(
self.compute_stats_batched)
self.process_branch = catch_map_batches_exception(
self.process_batched)
self.process = catch_map_batches_exception(self.process_batched)
else:
self.compute_stats_branch = catch_map_single_exception(
self.compute_stats)
self.process_branch = catch_map_single_exception(self.process)
self.compute_stats = catch_map_single_exception(
self.compute_stats_single)
self.process = catch_map_single_exception(self.process_single)

def compute_stats_batched(self, samples, *args, **kwargs):
keys = samples.keys()
samples_stats = samples[Fields.stats]
for i in range(len(samples_stats)):
num_samples = len(samples[Fields.stats])
for i in range(num_samples):
this_sample = {key: samples[key][i] for key in keys}
res_sample = self.compute_stats(this_sample, *args, **kwargs)
res_sample = self.compute_stats_single(this_sample, *args,
**kwargs)
samples[Fields.stats][i] = res_sample[Fields.stats]
if 'context' in kwargs and kwargs['context']:
samples[Fields.context][i] = res_sample[Fields.context]

return samples

def process_batched(self, samples):
return map(lambda stat: self.process({Fields.stats: stat}),
return map(lambda stat: self.process_single({Fields.stats: stat}),
samples[Fields.stats])

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
"""
Compute stats for the sample which is used as a metric to decide
whether to filter this sample.
Expand All @@ -333,7 +333,7 @@ def compute_stats(self, sample, context=False):
"""
raise NotImplementedError

def process(self, sample):
def process_single(self, sample):
"""
For sample level, sample --> Boolean.
Expand All @@ -354,14 +354,14 @@ def run(self, dataset, *, exporter=None, tracer=None):
num_proc=self.runtime_np(),
batch_size=self.batch_size,
desc='Adding new column for stats')
dataset = dataset.map(self.compute_stats_branch,
dataset = dataset.map(self.compute_stats,
num_proc=self.runtime_np(),
with_rank=self.use_cuda(),
batch_size=self.batch_size,
desc=self._name + '_compute_stats')
if exporter and self.stats_export_path is not None:
exporter.export_compute_stats(dataset, self.stats_export_path)
new_dataset = dataset.filter(self.process_branch,
new_dataset = dataset.filter(self.process,
num_proc=self.runtime_np(),
batch_size=self.batch_size,
desc=self._name + '_process')
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self,
pretrained_model_name_or_path='EleutherAI/pythia-6.9b-deduped',
return_model=False)

def compute_stats(self, samples):
def compute_stats_batched(self, samples):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]

Expand Down Expand Up @@ -79,7 +79,7 @@ def compute_stats(self, samples):

return samples

def process(self, samples):
def process_batched(self, samples):
ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \
else StatsKeys.alnum_ratio
if isinstance(samples[Fields.stats], list):
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/audio_duration_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.audio_duration in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -74,7 +74,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
audio_durations = sample[Fields.stats][StatsKeys.audio_duration]
keep_bools = np.array([
self.min_duration <= duration <= self.max_duration
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/audio_nmf_snr_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.audio_nmf_snr in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -124,7 +124,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
audio_snrs = sample[Fields.stats][StatsKeys.audio_nmf_snr]
keep_bools = np.array(
[self.min_snr <= snr <= self.max_snr for snr in audio_snrs])
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/audio_size_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.audio_sizes in sample[Fields.stats]:
return sample
Expand All @@ -58,7 +58,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
audio_sizes = sample[Fields.stats][StatsKeys.audio_sizes]
keep_bools = np.array([
self.min_size <= audio_size <= self.max_size
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/average_line_length_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self,
self.min_len = min_len
self.max_len = max_len

def compute_stats(self, samples, context=False):
def compute_stats_batched(self, samples, context=False):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]
context_key = f'{InterVars.lines}'
Expand All @@ -58,7 +58,7 @@ def compute_stats(self, samples, context=False):
len(cur_text) / len(lines) if len(lines) != 0 else 0.0
return samples

def process(self, samples):
def process_batched(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_len <= stat[StatsKeys.avg_line_length] <=
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/character_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self,
self.min_ratio = min_ratio
self.max_ratio = max_ratio

def compute_stats(self, samples):
def compute_stats_batched(self, samples):
samples_list = samples[self.text_key]
samples_stats = samples[Fields.stats]

Expand Down Expand Up @@ -78,7 +78,7 @@ def compute_stats(self, samples):

return samples

def process(self, samples):
def process_batched(self, samples):
if isinstance(samples[Fields.stats], list):
return map(
lambda stat: self.min_ratio <= stat[StatsKeys.char_rep_ratio]
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/flagged_words_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self,
self.model_key = prepare_model(model_type='sentencepiece',
lang=lang)

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.flagged_words_ratio in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -120,6 +120,6 @@ def compute_stats(self, sample, context=False):
StatsKeys.flagged_words_ratio] = flagged_words_ratio
return sample

def process(self, sample):
def process_single(self, sample):
return sample[Fields.stats][
StatsKeys.flagged_words_ratio] <= self.max_ratio
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_aesthetics_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(self,
self.need_normalized_by_ten = ('shunk031/aesthetics-predictor'
in hf_scorer_model)

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_aesthetics_scores in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -107,7 +107,7 @@ def compute_stats(self, sample, rank=None, context=False):
aesthetics_scores
return sample

def process(self, sample):
def process_single(self, sample):
aesthetics_scores = (
sample)[Fields.stats][StatsKeys.image_aesthetics_scores]
if len(aesthetics_scores) <= 0:
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_aspect_ratio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.aspect_ratios in sample[Fields.stats]:
return sample
Expand All @@ -66,7 +66,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios]
keep_bools = np.array([
self.min_ratio <= aspect_ratio <= self.max_ratio
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_face_ratio_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self,
self.model_key = prepare_model(model_type='opencv_classifier',
model_path=cv_classifier)

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.face_ratios in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -112,7 +112,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
face_ratios = sample[Fields.stats][StatsKeys.face_ratios]
if len(face_ratios) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self,
pretrained_model_name_or_path=hf_nsfw_model,
trust_remote_code=trust_remote_code)

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_nsfw_score in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -84,7 +84,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
itm_scores = sample[Fields.stats][StatsKeys.image_nsfw_score]
if len(itm_scores) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_pair_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self,
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):

# check if it's computed already
if StatsKeys.image_pair_similarity in sample[Fields.stats]:
Expand Down Expand Up @@ -97,7 +97,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.image_pair_similarity]
if len(similarity) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_shape_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.image_width in sample[Fields.stats] \
and StatsKeys.image_height in sample[Fields.stats]:
Expand Down Expand Up @@ -76,7 +76,7 @@ def compute_stats(self, sample, context=False):
]
return sample

def process(self, sample):
def process_single(self, sample):
ws = sample[Fields.stats][StatsKeys.image_width]
hs = sample[Fields.stats][StatsKeys.image_height]
if len(ws) <= 0:
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_size_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats(self, sample, context=False):
def compute_stats_single(self, sample, context=False):
# check if it's computed already
if StatsKeys.image_sizes in sample[Fields.stats]:
return sample
Expand All @@ -58,7 +58,7 @@ def compute_stats(self, sample, context=False):

return sample

def process(self, sample):
def process_single(self, sample):
image_sizes = sample[Fields.stats][StatsKeys.image_sizes]
keep_bools = np.array([
self.min_size <= image_size <= self.max_size
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_text_matching_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self,
self.horizontal_flip = horizontal_flip
self.vertical_flip = vertical_flip

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_text_matching_score in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -139,7 +139,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
itm_scores = sample[Fields.stats][StatsKeys.image_text_matching_score]
if len(itm_scores) <= 0:
return True
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/ops/filter/image_text_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self,
self.horizontal_flip = horizontal_flip
self.vertical_flip = vertical_flip

def compute_stats(self, sample, rank=None, context=False):
def compute_stats_single(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_text_similarity in sample[Fields.stats]:
return sample
Expand Down Expand Up @@ -136,7 +136,7 @@ def compute_stats(self, sample, rank=None, context=False):

return sample

def process(self, sample, rank=None):
def process_single(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.image_text_similarity]
if len(similarity) <= 0:
return True
Expand Down
Loading

0 comments on commit 5bf3173

Please sign in to comment.