Add DPO training (#1242)
* Add a chat data preprocessing script

* Add EOT token at the end of each chat

* Add different packing implementations (unpacked, pack-until-overflow)
- Fix labels to also cover the valid/test splits
- Fix label masking in _get_batch to also apply the mask from get_ltor_masks_and_position_ids

* Update README.md

* Add DPO-specific metrics (accuracy, etc.) to the forward step
- Add reference model setup for DPO
- Add pairwise dataset for positive/negative pairs
- Add DPO loss

* Update arguments.py to use train_label_data_paths instead of label_data_paths

* Bugfixes from upstreaming

* Add precompute logprobs

* Finish precompute logprobs

* Update README for DPO

* Fix variable name

* Fix pipeline parallelism and incorrect neox_args name

* Apply pre-commit

---------

Co-authored-by: Quentin Anthony <[email protected]>
dmahan93 and Quentin-Anthony authored Sep 8, 2024
1 parent ec82c05 commit 77e8158
Showing 10 changed files with 1,145 additions and 97 deletions.
27 changes: 27 additions & 0 deletions configs/README.md
@@ -235,6 +235,33 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in
"eval_iters": 10,
```

However, if you want to use DPO-style training, you'll need to set positive/negative data paths instead of a single one, e.g.:

```yaml
"dataset_impl": "pairwise",
"train_impl": "dpo",
"pack_impl": "unpacked",
"dpo_beta": 0.1,
"dpo_fp32": true,
"pos_train_data_path": "data/enwik8/enwik8_text_pos_document",
"pos_valid_data_path": "data/enwik8/enwik8_text_pos_document",
"pos_test_data_path": "data/enwik8/enwik8_text_pos_document",
"neg_train_data_path": "data/enwik8/enwik8_text_neg_document",
"neg_valid_data_path": "data/enwik8/enwik8_text_neg_document",
"neg_test_data_path": "data/enwik8/enwik8_text_neg_document",
## If you have labels (likely used to mask out user turns)...
"pos_train_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
"pos_valid_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
"pos_test_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
"neg_train_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
"neg_valid_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
"neg_test_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
## If you want to precompute the logits over your dataset...
"precompute_model_name": "gpt2",
## Needed for the generation.py step, if precomputing
"text_gen_type": "precompute"
```

### LR Scheduler settings

```yaml
3 changes: 3 additions & 0 deletions generate.py
@@ -23,6 +23,7 @@
generate_samples_from_prompt,
generate_samples_unconditional,
generate_samples_interactive,
precompute_logits,
)


@@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None):
top_p=neox_args.top_p,
)

elif neox_args.text_gen_type == "precompute":
precompute_logits(neox_args=neox_args, model=model)
else:
raise ValueError(
f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}"
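
Conceptually, the new "precompute" text_gen_type runs the model once over the dataset and stores per-token logprobs, so DPO training does not need to keep a live reference model in memory. A rough sketch of the idea follows; the real `precompute_logits` additionally handles batching, pipeline parallelism, and indexed-dataset output, and the names here are illustrative.

```python
# Rough conceptual sketch of a precompute pass; this is not the actual
# precompute_logits implementation. Assumes `model` maps token ids to logits.
import torch
import torch.nn.functional as F

def precompute_sequence_logprobs(model, token_batches):
    """Yield the logprob of each observed next token under `model`."""
    model.eval()
    for tokens in token_batches:  # tokens: LongTensor [batch, seq_len]
        with torch.no_grad():
            logits = model(tokens)  # assumed shape [batch, seq_len, vocab]
        logprobs = F.log_softmax(logits[:, :-1], dim=-1)
        # Pick out the logprob of each token that actually follows.
        yield torch.gather(
            logprobs, dim=-1, index=tokens[:, 1:].unsqueeze(-1)
        ).squeeze(-1)  # [batch, seq_len - 1]
```

As the data_utils.py changes below show, the precomputed logprobs are then read back as an indexed dataset whose prefix is the data prefix plus `"_" + precompute_model_name`.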
178 changes: 151 additions & 27 deletions megatron/data/data_utils.py
@@ -23,6 +23,7 @@
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data.pairwise_dataset import PairwiseDataset
from megatron.data.samplers import DistributedBatchSampler


@@ -53,43 +54,113 @@ def make_data_loader(dataset, neox_args):

def build_the_dataset(
data_prefix,
pos_data_prefix,
neg_data_prefix,
name,
data_impl,
pack_impl,
dataset_impl,
allow_chopped,
num_samples,
seq_length,
seed,
skip_warmup,
build_index_mappings=True,
label_prefix=None,
pos_label_prefix=None,
neg_label_prefix=None,
precompute_model_name=None,
):
"""Build train/valid/test datasets."""

if dataset_impl == "gpt2":
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
if label_prefix is None:
label_dataset = None
else:
label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
if precompute_model_name is not None:
# If we have the name, assume it exists. If it doesn't, it will just be None which is fine.
precompute_indexed_dataset = make_indexed_dataset(
data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
elif dataset_impl == "pairwise":
pos_indexed_dataset = make_indexed_dataset(
pos_data_prefix, data_impl, skip_warmup
)
neg_indexed_dataset = make_indexed_dataset(
neg_data_prefix, data_impl, skip_warmup
)
if pos_label_prefix is None:
pos_label_dataset = None
# Also do neg here since they both must be the same
assert neg_label_prefix is None
neg_label_dataset = None
else:
pos_label_dataset = make_indexed_dataset(
pos_label_prefix, data_impl, skip_warmup
)
# Also do neg here since they both must be the same
assert neg_label_prefix is not None
neg_label_dataset = make_indexed_dataset(
neg_label_prefix, data_impl, skip_warmup
)
if precompute_model_name is None:
pos_ref_dataset = None
neg_ref_dataset = None
else:
pos_ref_dataset = make_indexed_dataset(
pos_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
neg_ref_dataset = make_indexed_dataset(
neg_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
else:
raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented")

total_num_of_documents = (
indexed_dataset.sizes.shape[0]
if dataset_impl == "gpt2"
else pos_indexed_dataset.sizes.shape[0]
)
print_rank_0(" {}:".format(name))
print_rank_0(" no. of documents:{}".format(total_num_of_documents))
dataset = None
documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)

if dataset_impl == "gpt2":
dataset = GPT2Dataset(
name,
data_prefix,
documents,
indexed_dataset,
num_samples,
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
)
elif dataset_impl == "pairwise":
dataset = PairwiseDataset(
name,
pos_data_prefix,
documents,
pos_indexed_dataset,
neg_indexed_dataset,
num_samples,
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
build_index_mappings=build_index_mappings,
pos_label_dataset=pos_label_dataset,
neg_label_dataset=neg_label_dataset,
pos_ref_dataset=pos_ref_dataset,
neg_ref_dataset=neg_ref_dataset,
)

return dataset
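
For orientation, each sample from the new PairwiseDataset conceptually bundles the chosen and rejected sequences with optional label masks and optional precomputed reference logprobs. A hypothetical sketch, with field names that are ours rather than the dataset's actual keys:

```python
# Hypothetical shape of one pairwise sample; illustrative only, not the
# actual structure returned by megatron/data/pairwise_dataset.py.
from dataclasses import dataclass
from typing import Optional
import numpy as np

@dataclass
class PairwiseSample:
    pos_tokens: np.ndarray                   # chosen-response token ids
    neg_tokens: np.ndarray                   # rejected-response token ids
    pos_labels: Optional[np.ndarray] = None  # e.g. user turns masked out
    neg_labels: Optional[np.ndarray] = None
    pos_ref_logprobs: Optional[np.ndarray] = None  # set when precomputed
    neg_ref_logprobs: Optional[np.ndarray] = None
```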


@@ -135,7 +206,6 @@ def build_dataset(index, name):
documents = np.arange(
start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
)

dataset = GPT2Dataset(
name,
data_prefix,
@@ -219,21 +289,57 @@ def build_weighted_datasets(
valid_label_path,
test_path,
test_label_path,
pos_train_path,
neg_train_path,
pos_train_label_path,
neg_train_label_path,
pos_valid_path,
neg_valid_path,
pos_valid_label_path,
neg_valid_label_path,
pos_test_path,
neg_test_path,
pos_test_label_path,
neg_test_label_path,
) in enumerate(
zip_longest(
neox_args.train_data_paths if neox_args.train_data_paths else [],
neox_args.train_label_data_paths
if neox_args.train_label_data_paths
else [],
neox_args.valid_data_paths if neox_args.valid_data_paths else [],
neox_args.valid_label_data_paths
if neox_args.valid_label_data_paths
else [],
neox_args.test_data_paths if neox_args.test_data_paths else [],
neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
neox_args.pos_train_label_data_paths
if neox_args.pos_train_label_data_paths
else [],
neox_args.neg_train_label_data_paths
if neox_args.neg_train_label_data_paths
else [],
neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [],
neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [],
neox_args.pos_valid_label_data_paths
if neox_args.pos_valid_label_data_paths
else [],
neox_args.neg_valid_label_data_paths
if neox_args.neg_valid_label_data_paths
else [],
neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [],
neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [],
neox_args.pos_test_label_data_paths
if neox_args.pos_test_label_data_paths
else [],
neox_args.neg_test_label_data_paths
if neox_args.neg_test_label_data_paths
else [],
)
):
if train_path or pos_train_path:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
@@ -247,10 +353,16 @@
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=train_label_path,
dataset_impl=neox_args.dataset_impl,
pos_data_prefix=pos_train_path,
neg_data_prefix=neg_train_path,
pos_label_prefix=pos_train_label_path,
neg_label_prefix=neg_train_label_path,
precompute_model_name=neox_args.precompute_model_name,
)
)

if valid_path or pos_valid_path:
valid_datasets.append(
build_the_dataset(
data_prefix=valid_path,
@@ -264,10 +376,16 @@
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=valid_label_path,
dataset_impl=neox_args.dataset_impl,
pos_data_prefix=pos_valid_path,
neg_data_prefix=neg_valid_path,
pos_label_prefix=pos_valid_label_path,
neg_label_prefix=neg_valid_label_path,
precompute_model_name=neox_args.precompute_model_name,
)
)

if test_path or pos_test_path:
test_datasets.append(
build_the_dataset(
data_prefix=test_path,
@@ -281,6 +399,12 @@
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=test_label_path,
dataset_impl=neox_args.dataset_impl,
pos_data_prefix=pos_test_path,
neg_data_prefix=neg_test_path,
pos_label_prefix=pos_test_label_path,
neg_label_prefix=neg_test_label_path,
precompute_model_name=neox_args.precompute_model_name,
)
)
return train_datasets, valid_datasets, test_datasets
@@ -352,7 +476,7 @@ def build_train_valid_test_data_iterators(neox_args):
test_iters * neox_args.train_batch_size,
]

if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
# when individual train / valid / test data paths are provided
# normalize weight values and get num samples for each dataset
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
