Skip to content

Commit

Permalink
readded RM training removed during merge conflict in KTO (#1295)
Browse files Browse the repository at this point in the history
* readded RM training removed during merge conflict in KTO

* - parallel output updated
  • Loading branch information
dmahan93 authored Sep 26, 2024
1 parent 020ce55 commit b0b490d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
7 changes: 7 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,13 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes.
"""

z_loss: float = 0.0
"""
Z-loss parameter, only implemented for RM training currently.
https://arxiv.org/pdf/2204.02311
https://arxiv.org/pdf/2309.10305
"""

kto_beta: float = 0.1
"""
Beta value for KTO
Expand Down
34 changes: 30 additions & 4 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
base_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
Expand All @@ -111,7 +111,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
delta_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
Expand Down Expand Up @@ -324,7 +324,7 @@ def get_batch(neox_args, data_iterator):
# Items and their type.
if neox_args.train_impl in ["normal", "kto"]:
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
elif neox_args.train_impl == "dpo":
elif neox_args.train_impl in ["dpo", "rm"]:
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
if neox_args.pos_train_label_data_paths
Expand Down Expand Up @@ -575,6 +575,32 @@ def forward_step(
else:
moe_loss = 0.0
loss = main_loss + moe_loss
elif neox_args.train_impl == "rm":
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if type(maybe_tuple) is tuple:
outputs, _ = maybe_tuple
else:
outputs = maybe_tuple
pos, neg = torch.chunk(outputs, 2, 0)
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
# We assume that each pos, neg pair occur in the same order
# e.g. second nonzero pos is the corresponding second nonzero neg
# and that there are also an equal number of pos and neg in each sequence.
pos_indx = pos_loss_mask.nonzero()
neg_indx = neg_loss_mask.nonzero()
# indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index.
pos_indx = pos_indx[:, 1].unsqueeze(1)
neg_indx = neg_indx[:, 1].unsqueeze(1)
pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
with torch.no_grad():
metrics["pos_values"] = pos.clone().detach().mean()
metrics["neg_values"] = neg.clone().detach().mean()
metrics["margin"] = (pos - neg).clone().detach().mean()
metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
loss = (-F.logsigmoid(pos - neg).mean()) + (
(neox_args.z_loss * (pos**2 + neg**2)).mean()
)
elif neox_args.train_impl == "dpo":
# Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
with torch.inference_mode():
Expand Down Expand Up @@ -786,7 +812,7 @@ def get_model(neox_args, use_cache=False):
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
Expand Down

0 comments on commit b0b490d

Please sign in to comment.