diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 12b81e202..1b6909c9f 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -393,6 +393,7 @@ def load_checkpoint(
             load_lr_scheduler_states=load_optim_and_scheduler,
             load_module_only=not load_optim_and_scheduler,
             tag=tag,
+            load_module_strict=neox_args.train_impl != "rm",
         )
 
         if checkpoint_name is None:
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index f14076a17..523cbe4cf 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -199,7 +199,8 @@ def __init__(
         is_last_layer=False,
     ):
         super().__init__()
-        parallelism = neox_args.output_layer_parallelism
+        self.is_rm = neox_args.train_impl == "rm"
+        parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row"
         if parallelism == "column":
             self.final_linear = mpu.ColumnParallelLinear(
                 neox_args=neox_args,
@@ -212,26 +213,41 @@ def __init__(
                 mup_rescale_parameters=is_last_layer,  # rescale params only called if neox_args.use_mup = True, despite it not being included here
                 seq_dim=1,  # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0
             )
-
-        # else:
-        #     print(
-        #         'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
-        #     )
-        #     exit()
-        #     self.final_linear = mpu.RowParallelLinear(
-        #         neox_args=neox_args,
-        #         input_size=neox_args.hidden_size,
-        #         output_size=neox_args.padded_vocab_size,
-        #         bias=False,
-        #         input_is_parallel=False,
-        #         init_method=init_method,
-        #         parallel_output=parallel_output,
-        #         skip_bias_add=False,
-        #         mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
-        #     )
+        else:
+            if not self.is_rm:
+                print(
+                    'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
+                )
+                exit()
+                # self.final_linear = mpu.RowParallelLinear(
+                #     neox_args=neox_args,
+                #     input_size=neox_args.hidden_size,
+                #     output_size=neox_args.padded_vocab_size,
+                #     bias=False,
+                #     input_is_parallel=False,
+                #     init_method=init_method,
+                #     parallel_output=parallel_output,
+                #     skip_bias_add=False,
+                #     mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
+                # )
+            else:  # Not using cross entropy loss for RMs
+                self.rm_linear = mpu.RowParallelLinear(
+                    neox_args=neox_args,
+                    input_size=neox_args.hidden_size,
+                    output_size=1,
+                    bias=False,
+                    input_is_parallel=False,
+                    init_method=init_method,
+                    parallel_output=False,
+                    skip_bias_add=False,
+                    mup_rescale_parameters=is_last_layer,  # only called if neox_args.use_mup = True, despite it not being included here
+                )
 
     def forward(self, hidden_states):
-        return self.final_linear(hidden_states)
+        if not self.is_rm:
+            return self.final_linear(hidden_states)
+        else:
+            return self.rm_linear(hidden_states)
 
 
 class _MegablocksAdapter(nn.Module):
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index b5e7a619d..a87a573ae 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -997,9 +997,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Dataset implementation, can be one of "gpt2" or "pairwise"
     """
 
-    train_impl: Literal["normal", "dpo"] = "normal"
+    train_impl: Literal["normal", "dpo", "rm"] = "normal"
     """
-    Training implementation, can be one of "normal" or "dpo"
+    Training implementation, can be one of "normal", "dpo", or "rm"
     """
 
     dpo_fp32: bool = True
@@ -1012,6 +1012,13 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Beta value for DPO
     """
 
+    z_loss: float = 0.0
+    """
+    Z-loss parameter, only implemented for RM training currently.
+    https://arxiv.org/pdf/2204.02311
+    https://arxiv.org/pdf/2309.10305
+    """
+
     allow_chopped: bool = True
     """
     WARNING: if your packing impl is packed, this is ignored.
diff --git a/megatron/training.py b/megatron/training.py
index d9932483a..31e5c2bce 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -317,7 +317,7 @@ def get_batch(neox_args, data_iterator):
     # Items and their type.
     if neox_args.train_impl == "normal":
         keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
-    elif neox_args.train_impl == "dpo":
+    elif neox_args.train_impl in ["dpo", "rm"]:
         keys = (
             [["pos", "pos_label"], ["neg", "neg_label"]]
             if neox_args.pos_train_label_data_paths
@@ -338,7 +338,7 @@ def get_batch(neox_args, data_iterator):
             data=data,
             datatype=datatype,
         )
-    elif neox_args.train_impl == "dpo":
+    elif neox_args.train_impl in ["dpo", "rm"]:
         pos_tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -353,7 +353,7 @@ def get_batch(neox_args, data_iterator):
             data=data,
             datatype=datatype,
         )
-        if neox_args.precompute_model_name:
+        if (neox_args.precompute_model_name) and (neox_args.train_impl == "dpo"):
             ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float)
         else:
             ref_data = {"pos_ref": None}
@@ -491,7 +491,7 @@ def forward_step(
         tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
             neox_args=neox_args, data_iterator=data_iterator
         )
-    if neox_args.train_impl == "dpo":
+    if neox_args.train_impl in ["dpo", "rm"]:
         tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch(
             neox_args=neox_args, data_iterator=data_iterator
         )
@@ -532,6 +532,32 @@ def forward_step(
         else:
             moe_loss = 0.0
         loss = main_loss + moe_loss
+    elif neox_args.train_impl == "rm":
+        maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
+        if type(maybe_tuple) is tuple:
+            outputs, _ = maybe_tuple
+        else:
+            outputs = maybe_tuple
+        pos, neg = torch.chunk(outputs, 2, 0)
+        pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
+        # We assume that each pos, neg pair occur in the same order
+        # e.g. second nonzero pos is the corresponding second nonzero neg
+        # and that there are also an equal number of pos and neg in each sequence.
+        pos_indx = pos_loss_mask.nonzero()
+        neg_indx = neg_loss_mask.nonzero()
+        # indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index.
+        pos_indx = pos_indx[:, 1].unsqueeze(1)
+        neg_indx = neg_indx[:, 1].unsqueeze(1)
+        pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
+        neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
+        with torch.no_grad():
+            metrics["pos_values"] = pos.clone().detach().mean()
+            metrics["neg_values"] = neg.clone().detach().mean()
+            metrics["margin"] = (pos - neg).clone().detach().mean()
+            metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
+        loss = (-F.logsigmoid(pos - neg).mean()) + (
+            (neox_args.z_loss * (pos**2 + neg**2)).mean()
+        )
     elif neox_args.train_impl == "dpo":
         # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
         with torch.no_grad():
@@ -616,7 +642,7 @@ def get_model(neox_args, use_cache=False):
         model = GPT2ModelPipe(
             neox_args=neox_args,
             num_tokentypes=0,
-            parallel_output=True,
+            parallel_output=True if neox_args.train_impl != "rm" else False,
             topology=mpu.get_topology(),
             use_cache=use_cache,
         )
diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py
index 4e101ea5a..4d058127c 100644
--- a/tools/datasets/preprocess_data_with_chat_template.py
+++ b/tools/datasets/preprocess_data_with_chat_template.py
@@ -81,6 +81,7 @@ def build_chat(
     apply_mask: bool,
     tokenizer: PreTrainedTokenizer,
     only_last_turn: bool = False,
+    for_rm: bool = False,
 ) -> Tuple[List[int], List[int]]:
     """
     Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the
@@ -91,12 +92,28 @@ def build_chat(
     :param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss
     :param tokenizer: A HF tokenizer
     :param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks
+    :param for_rm: Whether this is for a reward model or not, this will mask everything except EOS token.
+        If you need a more complicated setup, you can modify this function to suit your needs.
     """
     tokens = []
     mask = []
     if apply_mask is False:
         tokens = tokenizer.apply_chat_template(chat)
         mask = tokens
+        if tokenizer.eos_token_id is not None:
+            mask.append(tokenizer.eos_token_id)
+            tokens.append(tokenizer.eos_token_id)
+        return tokens, mask
+    elif for_rm:
+        tokens = tokenizer.apply_chat_template(chat)
+        mask = [-100] * len(tokens)
+        if tokenizer.eos_token_id is not None:
+            mask.append(tokenizer.eos_token_id)
+            tokens.append(tokenizer.eos_token_id)
+        else:
+            raise ValueError(
+                "Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own."
+            )
         return tokens, mask
     for i, turn in enumerate(chat):
         add_gen = (
@@ -105,7 +122,7 @@ def build_chat(
         chat_tokens = tokenizer.apply_chat_template(
             chat[: i + 1], add_generation_prompt=add_gen
         )[len(tokens) :]
-
+        # remove previous stuff...
         tokens.extend(chat_tokens)
         if only_last_turn and (i != len(chat) - 1):
@@ -137,6 +154,7 @@ def encode(self, text):
                 not self.args.no_mask,
                 Encoder.tokenizer,
                 self.args.only_last,
+                self.args.for_rm,
             )
             ids[key] = (text_ids, label_ids)
         return ids, len(text)
@@ -163,6 +181,11 @@ def get_args():
         help="If set, this will not mask any tokens in the input data.",
         action="store_true",
     )
+    group.add_argument(
+        "--for-rm",
+        help="If set, this will mask everything except the last token in the chat.",
+        action="store_true",
+    )
     group.add_argument(
         "--generation-role",
         type=str,