diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index f55907f90..d8cc6b9ad 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -17,6 +17,7 @@ def __init__(self, suffixes: Union[str, List[str], Tuple[str]] = None, text_keys=None, add_suffix=False, + max_samples=None, **kwargs): """ Initialization method. @@ -28,9 +29,30 @@ def __init__(self, :param text_keys: key names of field that stores sample text. :param add_suffix: whether to add the file suffix to dataset meta info + :param max_samples: max samples number of mixed dataset. :param kwargs: extra args """ + data_prefixes, weights = self._get_weight(data_prefix=dataset_path) + sample_numbers = [0] * len(weights) + if max_samples is not None: + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + sample_num_per_dataset = [ + int(np.ceil(max_samples * weight)) for weight in weights + ] + + # Adjust + acc_sample_numbers = 0 + for i in range(len(sample_num_per_dataset)): + sample_numbers[i] = min(sample_num_per_dataset[i], + max_samples - acc_sample_numbers) + acc_sample_numbers += sample_numbers[i] + + self.sample_numbers = sample_numbers self.weights = weights self.formatters = [ load_formatter(dataset_path=data_prefix, @@ -65,21 +87,38 @@ def _get_weight(self, data_prefix): prefixes.append(value) return prefixes, weights - def _random_sample(self, dataset, weight=1.0, seed=None): + def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None): """ - Randomly sample a subset from a dataset with weight. + Randomly sample a subset from a dataset with weight or number, + if sample number is bigger than 0, we will use sample + number instead of weight. :param dataset: a HuggingFace dataset :param weight: sample ratio of dataset + :param sample_number: sample number of dataset :param seed: random sample seed, if None, 42 as default :return: a subset of dataset """ if seed is None: seed = 42 - num_samples = min(int(np.ceil(dataset.num_rows * weight)), - dataset.num_rows) - if num_samples == dataset.num_rows: + + ds_samples = dataset.num_rows + if sample_number <= 0: + sample_number = int(np.ceil(ds_samples * weight)) + + if sample_number == ds_samples: return dataset - return dataset.shuffle(seed=seed).select(range(num_samples)) + + num_epochs = int(np.ceil(sample_number / ds_samples)) - 1 + + if num_epochs > 0: + remain_samples = sample_number - num_epochs * ds_samples + sample_index = list(range(ds_samples)) * num_epochs + list( + range(remain_samples)) + else: + remain_samples = sample_number + sample_index = list(range(remain_samples)) + + return dataset.shuffle(seed=seed).select(sample_index) def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: """ @@ -90,11 +129,13 @@ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset: :return: mixed dataset """ dataset_list = [] - for weight, formatter in zip(self.weights, self.formatters): + for weight, sample_num, formatter in zip(self.weights, + self.sample_numbers, + self.formatters): dataset = formatter.load_dataset(num_proc, global_cfg) - sampled = self._random_sample(dataset, weight) + sampled = self._random_sample(dataset, weight, sample_num) logger.info(f'sampled {len(sampled)} from ' - f'{len(dataset)} with weight {weight}') + f'{len(dataset)}') dataset_list.append(sampled) from data_juicer.core.data import NestedDataset diff --git a/tools/postprocess/data_mixture.py b/tools/postprocess/data_mixture.py index 146986976..db89a2a1f 100644 --- a/tools/postprocess/data_mixture.py +++ b/tools/postprocess/data_mixture.py @@ -33,6 +33,11 @@ def parse_args(): 'size of each shard won\'t larger than the ' 'export_shard_size') + parser.add_argument('--max_samples', + type=int, + default=None, + help='Number of samples of mixed dataset.') + parser.add_argument('--num_proc', type=int, default=4, @@ -58,7 +63,7 @@ def run_mixture(): """ args = parse_args() data_path = ' '.join(args.data_path) - formatter = load_formatter(data_path) + formatter = load_formatter(data_path, max_samples=args.max_samples) dataset = formatter.load_dataset(args.num_proc) exporter = Exporter(export_path=args.export_path, export_shard_size=args.export_shard_size,