Re-added RM training removed during merge conflict in KTO #1295

Merged · 2 commits · Sep 26, 2024
7 changes: 7 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -1087,6 +1087,13 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes.
"""

z_loss: float = 0.0
"""
Z-loss coefficient; currently only implemented for RM training.
https://arxiv.org/pdf/2204.02311
https://arxiv.org/pdf/2309.10305
"""

kto_beta: float = 0.1
"""
Beta value for KTO
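For context, the z_loss coefficient adds a penalty on the squared magnitude of the reward logits, in the spirit of the softmax-normalizer z-loss from the two papers linked in the docstring. A minimal sketch of how the term behaves, with chosen_reward and rejected_reward as hypothetical per-example reward scalars (not identifiers from this repository):

import torch

# Hypothetical reward scalars for one chosen/rejected pair (illustrative values only).
chosen_reward = torch.tensor([1.7])
rejected_reward = torch.tensor([-0.4])

z_loss = 1e-4  # plays the role of neox_args.z_loss; the default of 0.0 disables the term
penalty = z_loss * (chosen_reward**2 + rejected_reward**2)  # discourages reward-magnitude drift
print(penalty.item())  # roughly 3.05e-4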
34 changes: 30 additions & 4 deletions megatron/training.py
@@ -87,7 +87,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
base_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
@@ -111,7 +111,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
delta_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
@@ -324,7 +324,7 @@ def get_batch(neox_args, data_iterator):
# Items and their type.
if neox_args.train_impl in ["normal", "kto"]:
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
elif neox_args.train_impl == "dpo":
elif neox_args.train_impl in ["dpo", "rm"]:
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
if neox_args.pos_train_label_data_paths
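The "rm" implementation reuses the paired pos/neg batch layout from DPO: the chosen and rejected sequences are concatenated along the batch dimension so that a single forward pass scores both halves, which is what the torch.chunk(outputs, 2, 0) call in forward_step relies on. A minimal sketch of that layout (shapes are hypothetical):

import torch

# Hypothetical token batches: [batch, seq_len] ids for chosen (pos) and rejected (neg) sequences.
pos_tokens = torch.randint(0, 50000, (4, 2048))
neg_tokens = torch.randint(0, 50000, (4, 2048))

# Stack along dim 0 so one forward pass scores both halves ...
tokens = torch.cat((pos_tokens, neg_tokens), dim=0)    # shape [8, 2048]

# ... and forward_step can split the model outputs back apart in the same order.
pos_half, neg_half = torch.chunk(tokens, 2, dim=0)     # each [4, 2048]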
@@ -575,6 +575,32 @@ def forward_step(
else:
moe_loss = 0.0
loss = main_loss + moe_loss
elif neox_args.train_impl == "rm":
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if type(maybe_tuple) is tuple:
outputs, _ = maybe_tuple
else:
outputs = maybe_tuple
pos, neg = torch.chunk(outputs, 2, 0)
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
# We assume that each pos/neg pair occurs in the same order,
# e.g. the second nonzero pos entry corresponds to the second nonzero neg entry,
# and that each sequence contains an equal number of pos and neg entries.
pos_indx = pos_loss_mask.nonzero()
neg_indx = neg_loss_mask.nonzero()
# indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index.
pos_indx = pos_indx[:, 1].unsqueeze(1)
neg_indx = neg_indx[:, 1].unsqueeze(1)
pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
with torch.no_grad():
metrics["pos_values"] = pos.clone().detach().mean()
metrics["neg_values"] = neg.clone().detach().mean()
metrics["margin"] = (pos - neg).clone().detach().mean()
metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
loss = (-F.logsigmoid(pos - neg).mean()) + (
(neox_args.z_loss * (pos**2 + neg**2)).mean()
)
elif neox_args.train_impl == "dpo":
# Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
with torch.inference_mode():
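To make the indexing in the new RM branch above easier to follow, here is a self-contained restatement of the same computation with assumed shapes (one scored token per sequence; this is a sketch, not a helper from the repository):

import torch
import torch.nn.functional as F

def rm_pairwise_loss(outputs, loss_mask, z_loss=0.0):
    # outputs:   [2*B, T, 1] reward logits, pos examples stacked before neg along dim 0.
    # loss_mask: [2*B, T] with a nonzero entry marking the token whose reward is scored.
    pos, neg = torch.chunk(outputs, 2, 0)               # each [B, T, 1]
    pos_mask, neg_mask = torch.chunk(loss_mask, 2, 0)   # each [B, T]

    # nonzero() returns (batch, token) index pairs; only the token index is needed
    # because the rows of pos and neg already line up pair-for-pair.
    pos_indx = pos_mask.nonzero()[:, 1].unsqueeze(1)    # [B, 1]
    neg_indx = neg_mask.nonzero()[:, 1].unsqueeze(1)    # [B, 1]

    # Gather the reward logit at each scored position.
    pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)   # [B, 1]
    neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)   # [B, 1]

    # Bradley-Terry pairwise loss plus the optional z-loss magnitude penalty.
    return (-F.logsigmoid(pos - neg).mean()) + (z_loss * (pos**2 + neg**2)).mean()

# Hypothetical usage: 2 pos/neg pairs, sequence length 8, final token scored.
outputs = torch.randn(4, 8, 1, requires_grad=True)
loss_mask = torch.zeros(4, 8)
loss_mask[:, -1] = 1.0
loss = rm_pairwise_loss(outputs, loss_mask, z_loss=1e-4)
loss.backward()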
@@ -786,7 +812,7 @@ def get_model(neox_args, use_cache=False):
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
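One small observation on the parallel_output change that appears three times in this file: True if neox_args.train_impl != "rm" else False is equivalent to the bare comparison, so the call sites could be written a little more directly (a style suggestion, not a behavioral change; the reading that RM training wants the gathered rather than tensor-parallel-split output is my interpretation):

# Equivalent, slightly more direct form of the repeated ternary.
# Presumably RM training consumes the full (gathered) output, so parallel_output is disabled for it.
parallel_output=neox_args.train_impl != "rm",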