Skip to content

Commit

Permalink
feature: add max_samples to limit mixed datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijianma committed Nov 17, 2023
1 parent afe06dc commit 81f5533
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 10 deletions.
59 changes: 50 additions & 9 deletions data_juicer/format/mixture_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self,
suffixes: Union[str, List[str], Tuple[str]] = None,
text_keys=None,
add_suffix=False,
max_samples=None,
**kwargs):
"""
Initialization method.
Expand All @@ -28,9 +29,30 @@ def __init__(self,
:param text_keys: key names of field that stores sample text.
:param add_suffix: whether to add the file suffix to dataset
meta info
:param max_samples: max samples number of mixed dataset.
:param kwargs: extra args
"""

data_prefixes, weights = self._get_weight(data_prefix=dataset_path)
sample_numbers = [0] * len(weights)
if max_samples is not None:
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
sample_num_per_dataset = [
int(np.ceil(max_samples * weight)) for weight in weights
]

# Adjust
acc_sample_numbers = 0
for i in range(len(sample_num_per_dataset)):
sample_numbers[i] = min(sample_num_per_dataset[i],
max_samples - acc_sample_numbers)
acc_sample_numbers += sample_numbers[i]

self.sample_numbers = sample_numbers
self.weights = weights
self.formatters = [
load_formatter(dataset_path=data_prefix,
Expand Down Expand Up @@ -65,21 +87,38 @@ def _get_weight(self, data_prefix):
prefixes.append(value)
return prefixes, weights

def _random_sample(self, dataset, weight=1.0, seed=None):
def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None):
"""
Randomly sample a subset from a dataset with weight.
Randomly sample a subset from a dataset with weight or number,
if sample number is bigger than 0, we will use sample
number instead of weight.
:param dataset: a HuggingFace dataset
:param weight: sample ratio of dataset
:param sample_number: sample number of dataset
:param seed: random sample seed, if None, 42 as default
:return: a subset of dataset
"""
if seed is None:
seed = 42
num_samples = min(int(np.ceil(dataset.num_rows * weight)),
dataset.num_rows)
if num_samples == dataset.num_rows:

ds_samples = dataset.num_rows
if sample_number <= 0:
sample_number = int(np.ceil(ds_samples * weight))

if sample_number == ds_samples:
return dataset
return dataset.shuffle(seed=seed).select(range(num_samples))

num_epochs = int(np.ceil(sample_number / ds_samples)) - 1

if num_epochs > 0:
remain_samples = sample_number - num_epochs * ds_samples
sample_index = list(range(ds_samples)) * num_epochs + list(
range(remain_samples))
else:
remain_samples = sample_number
sample_index = list(range(remain_samples))

return dataset.shuffle(seed=seed).select(sample_index)

def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
"""
Expand All @@ -90,11 +129,13 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
:return: mixed dataset
"""
dataset_list = []
for weight, formatter in zip(self.weights, self.formatters):
for weight, sample_num, formatter in zip(self.weights,
self.sample_numbers,
self.formatters):
dataset = formatter.load_dataset(num_proc, global_cfg)
sampled = self._random_sample(dataset, weight)
sampled = self._random_sample(dataset, weight, sample_num)
logger.info(f'sampled {len(sampled)} from '
f'{len(dataset)} with weight {weight}')
f'{len(dataset)}')
dataset_list.append(sampled)

from data_juicer.core.data import NestedDataset
Expand Down
7 changes: 6 additions & 1 deletion tools/postprocess/data_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def parse_args():
'size of each shard won\'t larger than the '
'export_shard_size')

parser.add_argument('--max_samples',
type=int,
default=None,
help='Number of samples of mixed dataset.')

parser.add_argument('--num_proc',
type=int,
default=4,
Expand All @@ -58,7 +63,7 @@ def run_mixture():
"""
args = parse_args()
data_path = ' '.join(args.data_path)
formatter = load_formatter(data_path)
formatter = load_formatter(data_path, max_samples=args.max_samples)
dataset = formatter.load_dataset(args.num_proc)
exporter = Exporter(export_path=args.export_path,
export_shard_size=args.export_shard_size,
Expand Down

0 comments on commit 81f5533

Please sign in to comment.