From b0b490dea75cca93d7c89f0e5b18dadb006704dd Mon Sep 17 00:00:00 2001 From: dmahan93 <44207705+dmahan93@users.noreply.github.com> Date: Thu, 26 Sep 2024 13:18:46 -0500 Subject: [PATCH] readded RM training removed during merge conflict in KTO (#1295) * readded RM training removed during merge conflict in KTO * - parallel output updated --- megatron/neox_arguments/neox_args.py | 7 ++++++ megatron/training.py | 34 ++++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index df7e51da6..0dbdc8be0 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1087,6 +1087,13 @@ class NeoXArgsTraining(NeoXArgsTemplate): Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. """ + z_loss: float = 0.0 + """ + Z-loss parameter, only implemented for RM training currently. + https://arxiv.org/pdf/2204.02311 + https://arxiv.org/pdf/2309.10305 + """ + kto_beta: float = 0.1 """ Beta value for KTO diff --git a/megatron/training.py b/megatron/training.py index cca3ab3f9..72cc551a8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -87,7 +87,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache): base_model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, ) @@ -111,7 +111,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache): delta_model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, ) @@ -324,7 +324,7 @@ def get_batch(neox_args, data_iterator): # Items and their type. 
if neox_args.train_impl in ["normal", "kto"]: keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] - elif neox_args.train_impl == "dpo": + elif neox_args.train_impl in ["dpo", "rm"]: keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] if neox_args.pos_train_label_data_paths @@ -575,6 +575,32 @@ def forward_step( else: moe_loss = 0.0 loss = main_loss + moe_loss + elif neox_args.train_impl == "rm": + maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(maybe_tuple) is tuple: + outputs, _ = maybe_tuple + else: + outputs = maybe_tuple + pos, neg = torch.chunk(outputs, 2, 0) + pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) + # We assume that each pos, neg pair occurs in the same order + # e.g. second nonzero pos is the corresponding second nonzero neg + # and that there are also an equal number of pos and neg in each sequence. + pos_indx = pos_loss_mask.nonzero() + neg_indx = neg_loss_mask.nonzero() + # indx[:, 0] is the batch index, indx[:, 1] is the token index; we only care about the token index. 
+ pos_indx = pos_indx[:, 1].unsqueeze(1) + neg_indx = neg_indx[:, 1].unsqueeze(1) + pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx) + neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx) + with torch.no_grad(): + metrics["pos_values"] = pos.clone().detach().mean() + metrics["neg_values"] = neg.clone().detach().mean() + metrics["margin"] = (pos - neg).clone().detach().mean() + metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean() + loss = (-F.logsigmoid(pos - neg).mean()) + ( + (neox_args.z_loss * (pos**2 + neg**2)).mean() + ) elif neox_args.train_impl == "dpo": # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 with torch.inference_mode(): @@ -786,7 +812,7 @@ def get_model(neox_args, use_cache=False): model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, )