
Add Reward Model training (#1246)
* Add a chat data preprocessing script

* Add EOT at end of a chat

* - Add different packing impls (unpacked, packing until overflow)
- Fix labels to also have valid/test implementations
- Fix label masking in _get_batch to also include anything from get_ltor_masks_and_position_ids

* Update README.md

* - Add metrics to the forward step for useful DPO-specific metrics (accuracy, etc.)
- Add reference model setup for DPO
- Add pairwise dataset for positive/negative pairs
- Add DPO loss

* Update arguments.py to use train_label_data_paths instead of label_data_paths

* Bugfixes from upstreaming

* Add logprob precomputation

* Finish logprob precomputation

* Update README for DPO

* Add RM training

* Add a comment on why RMs use a row-parallel output layer

* Fix variable name

---------

Co-authored-by: Quentin Anthony <[email protected]>
dmahan93 and Quentin-Anthony authored Sep 9, 2024
1 parent 61a3daa commit 1c72742
Showing 5 changed files with 100 additions and 27 deletions.
1 change: 1 addition & 0 deletions megatron/checkpointing.py
@@ -393,6 +393,7 @@ def load_checkpoint(
load_lr_scheduler_states=load_optim_and_scheduler,
load_module_only=not load_optim_and_scheduler,
tag=tag,
+        load_module_strict=neox_args.train_impl != "rm",
)

if checkpoint_name is None:
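The `load_module_strict` flag above lets an RM run start from a pretrained checkpoint even though the freshly added reward head has no saved weights. A minimal sketch of the idea in plain PyTorch rather than DeepSpeed's loader (the `RewardModel` class and shapes here are hypothetical, not from the repo):

```python
import torch
import torch.nn as nn

# Hypothetical reward model: a pretrained backbone plus a new scalar head
# that does not exist in the base checkpoint.
class RewardModel(nn.Module):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)
        self.rm_linear = nn.Linear(hidden_size, 1, bias=False)

model = RewardModel()
base_ckpt = {
    "backbone.weight": torch.zeros(16, 16),
    "backbone.bias": torch.zeros(16),
}

# strict=False tolerates the missing rm_linear.weight, mirroring
# load_module_strict=neox_args.train_impl != "rm" above.
missing, unexpected = model.load_state_dict(base_ckpt, strict=False)
print(missing)  # ['rm_linear.weight']
```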
54 changes: 35 additions & 19 deletions megatron/model/transformer.py
@@ -199,7 +199,8 @@ def __init__(
is_last_layer=False,
):
super().__init__()
-        parallelism = neox_args.output_layer_parallelism
+        self.is_rm = neox_args.train_impl == "rm"
+        parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row"
if parallelism == "column":
self.final_linear = mpu.ColumnParallelLinear(
neox_args=neox_args,
@@ -212,26 +213,41 @@ def __init__(
mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here
seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0
)

-        # else:
-        #     print(
-        #         'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
-        #     )
-        #     exit()
-        #     self.final_linear = mpu.RowParallelLinear(
-        #         neox_args=neox_args,
-        #         input_size=neox_args.hidden_size,
-        #         output_size=neox_args.padded_vocab_size,
-        #         bias=False,
-        #         input_is_parallel=False,
-        #         init_method=init_method,
-        #         parallel_output=parallel_output,
-        #         skip_bias_add=False,
-        #         mup_rescale_parameters=is_last_layer,  # only called if neox_args.use_mup = True, despite it not being included here
-        #     )
+        else:
+            if not self.is_rm:
+                print(
+                    'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
+                )
+                exit()
+                # self.final_linear = mpu.RowParallelLinear(
+                #     neox_args=neox_args,
+                #     input_size=neox_args.hidden_size,
+                #     output_size=neox_args.padded_vocab_size,
+                #     bias=False,
+                #     input_is_parallel=False,
+                #     init_method=init_method,
+                #     parallel_output=parallel_output,
+                #     skip_bias_add=False,
+                #     mup_rescale_parameters=is_last_layer,  # only called if neox_args.use_mup = True, despite it not being included here
+                # )
+            else:  # Not using cross entropy loss for RMs
+                self.rm_linear = mpu.RowParallelLinear(
+                    neox_args=neox_args,
+                    input_size=neox_args.hidden_size,
+                    output_size=1,
+                    bias=False,
+                    input_is_parallel=False,
+                    init_method=init_method,
+                    parallel_output=False,
+                    skip_bias_add=False,
+                    mup_rescale_parameters=is_last_layer,  # only called if neox_args.use_mup = True, despite it not being included here
+                )

def forward(self, hidden_states):
-        return self.final_linear(hidden_states)
+        if not self.is_rm:
+            return self.final_linear(hidden_states)
+        else:
+            return self.rm_linear(hidden_states)


class _MegablocksAdapter(nn.Module):
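A column-parallel output layer shards the vocabulary dimension across tensor-parallel ranks, but the RM head projects to a single scalar, which cannot usefully be column-split; `RowParallelLinear` instead shards the input (hidden) dimension and all-reduces the partial sums. A single-process sketch of what the head computes (shapes are illustrative):

```python
import torch
import torch.nn as nn

# Single-GPU stand-in for the rm_linear head above: project each token's
# hidden state to one scalar score. With row parallelism the hidden dim is
# sharded across ranks and the partial products are all-reduced, so an
# output dim of 1 is no obstacle.
hidden_size = 8
rm_head = nn.Linear(hidden_size, 1, bias=False)

hidden_states = torch.randn(2, 5, hidden_size)  # [batch, seq, hidden]
scores = rm_head(hidden_states)                 # [batch, seq, 1]
print(scores.shape)  # torch.Size([2, 5, 1])
```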
11 changes: 9 additions & 2 deletions megatron/neox_arguments/neox_args.py
@@ -997,9 +997,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Dataset implementation, can be one of "gpt2" or "pairwise"
"""

-    train_impl: Literal["normal", "dpo"] = "normal"
+    train_impl: Literal["normal", "dpo", "rm"] = "normal"
    """
-    Training implementation, can be one of "normal" or "dpo"
+    Training implementation, can be one of "normal", "dpo", or "rm"
"""

dpo_fp32: bool = True
@@ -1012,6 +1012,13 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Beta value for DPO
"""

+    z_loss: float = 0.0
+    """
+    Z-loss parameter, only implemented for RM training currently.
+    https://arxiv.org/pdf/2204.02311
+    https://arxiv.org/pdf/2309.10305
+    """

allow_chopped: bool = True
"""
WARNING: if your packing impl is packed, this is ignored.
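Per the training.py diff below, `z_loss` is applied as a quadratic penalty on the raw reward scores, discouraging them from drifting to large magnitudes. A toy sketch of just the penalty term (values and coefficient are made up):

```python
import torch

# Toy reward scores for two chosen/rejected pairs.
pos = torch.tensor([1.2, 0.3])
neg = torch.tensor([-0.5, 0.6])
z_loss = 1e-4  # illustrative coefficient; 0.0 disables the penalty

# Quadratic penalty keeping raw scores near zero, as in the RM loss below.
z_penalty = (z_loss * (pos**2 + neg**2)).mean()
print(z_penalty)  # tensor(1.0700e-04)
```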
36 changes: 31 additions & 5 deletions megatron/training.py
@@ -317,7 +317,7 @@ def get_batch(neox_args, data_iterator):
# Items and their type.
if neox_args.train_impl == "normal":
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
-    elif neox_args.train_impl == "dpo":
+    elif neox_args.train_impl in ["dpo", "rm"]:
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
if neox_args.pos_train_label_data_paths
@@ -338,7 +338,7 @@
data=data,
datatype=datatype,
)
-    elif neox_args.train_impl == "dpo":
+    elif neox_args.train_impl in ["dpo", "rm"]:
pos_tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
@@ -353,7 +353,7 @@
data=data,
datatype=datatype,
)
-        if neox_args.precompute_model_name:
+        if (neox_args.precompute_model_name) and (neox_args.train_impl == "dpo"):
ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float)
else:
ref_data = {"pos_ref": None}
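For the pairwise paths, the chosen and rejected sequences appear to share one forward pass, stacked along the batch dimension; the `torch.chunk(outputs, 2, 0)` in the RM branch below relies on that layout. An illustrative sketch with assumed shapes:

```python
import torch

# Assumed pairwise batch layout: "pos" (chosen) sequences stacked on top of
# "neg" (rejected) ones, so a single forward pass scores both halves.
batch, seq = 2, 6
pos_tokens = torch.randint(0, 100, (batch, seq))
neg_tokens = torch.randint(0, 100, (batch, seq))
tokens = torch.cat([pos_tokens, neg_tokens], dim=0)  # [2 * batch, seq]
# Later: pos, neg = torch.chunk(model_outputs, 2, 0) splits them back apart.
```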
@@ -491,7 +491,7 @@ def forward_step(
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)
-    if neox_args.train_impl == "dpo":
+    if neox_args.train_impl in ["dpo", "rm"]:
tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)
@@ -532,6 +532,32 @@ def forward_step(
else:
moe_loss = 0.0
loss = main_loss + moe_loss
+    elif neox_args.train_impl == "rm":
+        maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
+        if type(maybe_tuple) is tuple:
+            outputs, _ = maybe_tuple
+        else:
+            outputs = maybe_tuple
+        pos, neg = torch.chunk(outputs, 2, 0)
+        pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
+        # We assume that each pos, neg pair occur in the same order
+        # e.g. second nonzero pos is the corresponding second nonzero neg
+        # and that there are also an equal number of pos and neg in each sequence.
+        pos_indx = pos_loss_mask.nonzero()
+        neg_indx = neg_loss_mask.nonzero()
+        # indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index.
+        pos_indx = pos_indx[:, 1].unsqueeze(1)
+        neg_indx = neg_indx[:, 1].unsqueeze(1)
+        pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
+        neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
+        with torch.no_grad():
+            metrics["pos_values"] = pos.clone().detach().mean()
+            metrics["neg_values"] = neg.clone().detach().mean()
+            metrics["margin"] = (pos - neg).clone().detach().mean()
+            metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
+        loss = (-F.logsigmoid(pos - neg).mean()) + (
+            (neox_args.z_loss * (pos**2 + neg**2)).mean()
+        )
elif neox_args.train_impl == "dpo":
# Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
with torch.no_grad():
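Putting the RM branch above together: the per-token scores are chunked into pos/neg halves, the score at each masked (EOS) position is gathered, and a Bradley-Terry ranking loss pushes chosen above rejected; the `accuracy` metric is simply the fraction of pairs where the chosen score wins. A self-contained sketch with dummy shapes:

```python
import torch
import torch.nn.functional as F

# Dummy stand-ins: 2 pairs, sequence length 6, scores read at the last token.
batch, seq = 2, 6
outputs = torch.randn(2 * batch, seq, 1)  # pos rows first, then neg rows
loss_mask = torch.zeros(2 * batch, seq)
loss_mask[:, -1] = 1  # only the EOS position carries a score

pos, neg = torch.chunk(outputs, 2, 0)
pos_mask, neg_mask = torch.chunk(loss_mask, 2, 0)

# nonzero() rows are (batch_idx, token_idx); keep the token index per pair.
pos_idx = pos_mask.nonzero()[:, 1].unsqueeze(1)
neg_idx = neg_mask.nonzero()[:, 1].unsqueeze(1)
pos = torch.gather(pos.squeeze(-1), dim=1, index=pos_idx)  # [batch, 1]
neg = torch.gather(neg.squeeze(-1), dim=1, index=neg_idx)

z_loss = 1e-4  # illustrative coefficient
loss = -F.logsigmoid(pos - neg).mean() + (z_loss * (pos**2 + neg**2)).mean()
print(loss)
```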
@@ -616,7 +642,7 @@ def get_model(neox_args, use_cache=False):
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
-        parallel_output=True,
+        parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
25 changes: 24 additions & 1 deletion tools/datasets/preprocess_data_with_chat_template.py
@@ -81,6 +81,7 @@ def build_chat(
apply_mask: bool,
tokenizer: PreTrainedTokenizer,
only_last_turn: bool = False,
+    for_rm: bool = False,
) -> Tuple[List[int], List[int]]:
"""
Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the
@@ -91,12 +92,28 @@
:param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss
:param tokenizer: A HF tokenizer
:param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks
+    :param for_rm: Whether this is for a reward model or not, this will mask everything except EOS token.
+        If you need a more complicated setup, you can modify this function to suit your needs.
"""
tokens = []
mask = []
if apply_mask is False:
tokens = tokenizer.apply_chat_template(chat)
mask = tokens
+        if tokenizer.eos_token_id is not None:
+            mask.append(tokenizer.eos_token_id)
+            tokens.append(tokenizer.eos_token_id)
return tokens, mask
+    elif for_rm:
+        tokens = tokenizer.apply_chat_template(chat)
+        mask = [-100] * len(tokens)
+        if tokenizer.eos_token_id is not None:
+            mask.append(tokenizer.eos_token_id)
+            tokens.append(tokenizer.eos_token_id)
+        else:
+            raise ValueError(
+                "Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own."
+            )
+        return tokens, mask
for i, turn in enumerate(chat):
add_gen = (
@@ -105,7 +122,7 @@
chat_tokens = tokenizer.apply_chat_template(
chat[: i + 1], add_generation_prompt=add_gen
)[len(tokens) :]
-
+        # remove previous stuff...
tokens.extend(chat_tokens)
if only_last_turn and (i != len(chat) - 1):
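To make the `for_rm` masking concrete: every chat token gets a label of -100 and only the appended EOS token keeps a real label, so the reward is read off the end of the conversation. A toy illustration (the token ids and EOS id are invented):

```python
# What build_chat(..., for_rm=True) produces, with made-up ids.
chat_token_ids = [5, 9, 7]  # pretend tokenizer.apply_chat_template(chat) output
eos_token_id = 2            # pretend tokenizer.eos_token_id

tokens = chat_token_ids + [eos_token_id]              # [5, 9, 7, 2]
mask = [-100] * len(chat_token_ids) + [eos_token_id]  # [-100, -100, -100, 2]
# Only the final EOS position survives label masking downstream.
```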
@@ -137,6 +154,7 @@ def encode(self, text):
not self.args.no_mask,
Encoder.tokenizer,
self.args.only_last,
+            self.args.for_rm,
)
ids[key] = (text_ids, label_ids)
return ids, len(text)
@@ -163,6 +181,11 @@ def get_args():
help="If set, this will not mask any tokens in the input data.",
action="store_true",
)
+    group.add_argument(
+        "--for-rm",
+        help="If set, this will mask everything except the last token in the chat.",
+        action="store_true",
+    )
group.add_argument(
"--generation-role",
type=str,
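A possible invocation of the updated preprocessing script; only `--for-rm` is confirmed by this diff, while the other flag names and paths are assumptions about the script's existing arguments and should be checked against `get_args()`:

```bash
python tools/datasets/preprocess_data_with_chat_template.py \
  --input data/pairwise_chats.jsonl \
  --output-prefix data/rm_train \
  --tokenizer-path checkpoints/tokenizer \
  --for-rm \
  --workers 8
```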
