diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index 8509c1ba0..ae40b461c 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -118,22 +118,20 @@ def process(self, samples): texts_to_aug = samples[self.text_key][0] # batch_size = 1 res_samples = deepcopy(samples) + # get augmented texts if self.sequential: aug_texts = self.aug.augment(texts_to_aug, n=self.aug_num) - # add augmented samples to the batch with other replicate fields - for key in res_samples: - if key == self.text_key: - res_samples[self.text_key] += aug_texts - else: - res_samples[key] += res_samples[key] * len(aug_texts) else: # apply each aug method to generate several augmented texts + aug_texts = [] for aug_method in self.aug: - aug_texts = aug_method.augment(texts_to_aug, n=self.aug_num) - res_samples[self.text_key] += aug_texts - # add other replicate fields - for key in res_samples: - if key != self.text_key: - res_samples[key] = res_samples[key] * \ - len(res_samples[self.text_key]) + aug_texts += aug_method.augment(texts_to_aug, n=self.aug_num) + + # add augmented samples to the batch with other replicate fields + res_samples[self.text_key] += aug_texts + # add other replicate fields + for key in res_samples: + if key != self.text_key: + res_samples[key] = res_samples[key] * \ + len(res_samples[self.text_key]) return res_samples diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py index 51cf50e49..3f10b2f58 100644 --- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py +++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py @@ -125,6 +125,7 @@ def process(self, samples): texts_to_aug = samples[self.text_key] res_samples = deepcopy(samples) + # get augmented texts if self.sequential: aug_texts = texts_to_aug for aug_method in self.aug_pipeline: @@ -136,20 +137,17 @@ def process(self, samples): aug_texts = results[:] if len(aug_texts) == 1 and aug_texts[0] == texts_to_aug[0]: aug_texts = [] - # add augmented samples to the batch with other replicate fields - for key in res_samples: - if key == self.text_key: - res_samples[self.text_key] += aug_texts - else: - res_samples[key] += res_samples[key] * len(aug_texts) else: # apply each aug method to generate several augmented texts + aug_texts = [] for aug_method in self.aug_pipeline: - aug_texts = aug_method.replace(texts_to_aug[0])[1:] - res_samples[self.text_key] += aug_texts - # add other replicate fields - for key in res_samples: - if key != self.text_key: - res_samples[key] = res_samples[key] * \ - len(res_samples[self.text_key]) + aug_texts += aug_method.replace(texts_to_aug[0])[1:] + + # add augmented samples to the batch with other replicate fields + res_samples[self.text_key] += aug_texts + # add other replicate fields + for key in res_samples: + if key != self.text_key: + res_samples[key] = res_samples[key] * \ + len(res_samples[self.text_key]) return res_samples