Skip to content

Commit

Permalink
* reconstruct the code structure of two aug mappers
Browse files: browse the repository at this point in the history
  • Loading branch information
HYLcool committed Nov 15, 2023
1 parent baddb14 commit 35c5a8c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 26 deletions.
24 changes: 11 additions & 13 deletions data_juicer/ops/mapper/nlpaug_en_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,20 @@ def process(self, samples):
texts_to_aug = samples[self.text_key][0] # batch_size = 1
res_samples = deepcopy(samples)

# get augmented texts
if self.sequential:
aug_texts = self.aug.augment(texts_to_aug, n=self.aug_num)
# add augmented samples to the batch with other replicate fields
for key in res_samples:
if key == self.text_key:
res_samples[self.text_key] += aug_texts
else:
res_samples[key] += res_samples[key] * len(aug_texts)
else:
# apply each aug method to generate several augmented texts
aug_texts = []
for aug_method in self.aug:
aug_texts = aug_method.augment(texts_to_aug, n=self.aug_num)
res_samples[self.text_key] += aug_texts
# add other replicate fields
for key in res_samples:
if key != self.text_key:
res_samples[key] = res_samples[key] * \
len(res_samples[self.text_key])
aug_texts += aug_method.augment(texts_to_aug, n=self.aug_num)

# add augmented samples to the batch with other replicate fields
res_samples[self.text_key] += aug_texts
# add other replicate fields
for key in res_samples:
if key != self.text_key:
res_samples[key] = res_samples[key] * \
len(res_samples[self.text_key])
return res_samples
24 changes: 11 additions & 13 deletions data_juicer/ops/mapper/nlpcda_zh_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def process(self, samples):
texts_to_aug = samples[self.text_key]
res_samples = deepcopy(samples)

# get augmented texts
if self.sequential:
aug_texts = texts_to_aug
for aug_method in self.aug_pipeline:
Expand All @@ -136,20 +137,17 @@ def process(self, samples):
aug_texts = results[:]
if len(aug_texts) == 1 and aug_texts[0] == texts_to_aug[0]:
aug_texts = []
# add augmented samples to the batch with other replicate fields
for key in res_samples:
if key == self.text_key:
res_samples[self.text_key] += aug_texts
else:
res_samples[key] += res_samples[key] * len(aug_texts)
else:
# apply each aug method to generate several augmented texts
aug_texts = []
for aug_method in self.aug_pipeline:
aug_texts = aug_method.replace(texts_to_aug[0])[1:]
res_samples[self.text_key] += aug_texts
# add other replicate fields
for key in res_samples:
if key != self.text_key:
res_samples[key] = res_samples[key] * \
len(res_samples[self.text_key])
aug_texts += aug_method.replace(texts_to_aug[0])[1:]

# add augmented samples to the batch with other replicate fields
res_samples[self.text_key] += aug_texts
# add other replicate fields
for key in res_samples:
if key != self.text_key:
res_samples[key] = res_samples[key] * \
len(res_samples[self.text_key])
return res_samples

0 comments on commit 35c5a8c

Please sign in to comment.