feat(RLHF): update reward modelling code with new arguments
rcmalli committed Jun 27, 2023
1 parent 7991be8 commit 353be25
Showing 2 changed files with 44 additions and 21 deletions.
33 changes: 32 additions & 1 deletion sheeprl/algos/rlhf/loss.py
@@ -23,7 +23,27 @@ def reward_loss_last_token(
return -F.logsigmoid(filtered_rewards).mean(), chosen_last_rewards, rejected_last_rewards


def reward_loss(
def reward_loss_average(
chosen: torch.Tensor,
rejected: torch.Tensor,
chosen_rewards: torch.Tensor,
rejected_rewards: torch.Tensor,
pad_token_id: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""This loss computes the logsigmoid of the difference between the chosen and rejected rewards from average of all output tokens excluding padding tokens"""
pad_mask_chosen = chosen != pad_token_id # (B, T)
pad_mask_rejected = rejected != pad_token_id # (B, T)

chosen_rewards_average = chosen_rewards * pad_mask_chosen
chosen_rewards_average = chosen_rewards_average.sum(dim=1) / pad_mask_chosen.sum(dim=1)
rejected_rewards_average = rejected_rewards * pad_mask_rejected
rejected_rewards_average = rejected_rewards_average.sum(dim=1) / pad_mask_rejected.sum(dim=1)

filtered_rewards = chosen_rewards_average - rejected_rewards_average
return -F.logsigmoid(filtered_rewards).mean(), chosen_rewards_average, rejected_rewards_average


def reward_loss_per_sample(
chosen: torch.Tensor,
rejected: torch.Tensor,
chosen_rewards: torch.Tensor,
@@ -96,6 +116,17 @@ def reward_loss(
return loss, chosen_last_rewards, rejected_last_rewards


def load_reward_loss(reward_loss_type: str):
if reward_loss_type == "average":
return reward_loss_average
elif reward_loss_type == "last_token":
return reward_loss_last_token
elif reward_loss_type == "per_sample":
return reward_loss_per_sample
else:
raise ValueError(f"Invalid reward loss type: {reward_loss_type}")


def finetune_loss(
outputs: torch.Tensor, targets: torch.Tensor, ignore_index: int = -100, label_smoothing: float = 0.0
) -> torch.Tensor:
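To make the new masked-average loss concrete, here is a minimal toy check, not part of the commit: it assumes reward_loss_average is importable from sheeprl.algos.rlhf.loss as added above, and uses a pad token id of 0 purely for illustration.

import torch
from sheeprl.algos.rlhf.loss import reward_loss_average

PAD = 0
# (B=1, T=4) token ids; the chosen sequence has one pad token, the rejected one has two
chosen = torch.tensor([[5, 6, 7, PAD]])
rejected = torch.tensor([[5, 6, PAD, PAD]])
# Per-token rewards; values at pad positions (9.0 here) should be masked out
chosen_rewards = torch.tensor([[1.0, 2.0, 3.0, 9.0]])
rejected_rewards = torch.tensor([[0.5, 1.5, 9.0, 9.0]])

loss, chosen_avg, rejected_avg = reward_loss_average(
    chosen=chosen,
    rejected=rejected,
    chosen_rewards=chosen_rewards,
    rejected_rewards=rejected_rewards,
    pad_token_id=PAD,
)
# chosen_avg   = (1.0 + 2.0 + 3.0) / 3 = 2.0
# rejected_avg = (0.5 + 1.5) / 2       = 1.0
# loss = -logsigmoid(2.0 - 1.0) ≈ 0.3133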
32 changes: 12 additions & 20 deletions sheeprl/algos/rlhf/rlhf_rm.py
@@ -3,6 +3,7 @@
import sys
import time
from dataclasses import asdict
from typing import Callable

import lightning as L
import torch
@@ -18,13 +19,13 @@

from sheeprl.algos.rlhf.args import OPT, GenerationArgs, ModelArgs, RMArgs, TextDataArgs
from sheeprl.algos.rlhf.data import RMCollate
from sheeprl.algos.rlhf.lora_utils import add_lora
from sheeprl.algos.rlhf.loss import reward_loss
from sheeprl.algos.rlhf.loss import load_reward_loss
from sheeprl.algos.rlhf.models import CriticModel
from sheeprl.algos.rlhf.scheduler import CosineSchedulerWithWarmup
from sheeprl.algos.rlhf.utils import (
prepare_optimizer_parameters,
save_args_to_json,
setup_finetuning,
trainable_parameter_summary,
)

@@ -45,15 +46,15 @@ def accuracy(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):


@torch.inference_mode()
def evaluate(model: CriticModel, test_dataloader: DataLoader, pad_token_id: int, eval_iters: int):
def evaluate(model: CriticModel, test_dataloader: DataLoader, loss: Callable, pad_token_id: int, eval_iters: int):
model.eval()
eval_counter = 0
average_acc = 0
average_loss = 0
for batch in test_dataloader:
chosen_rewards = model(input_ids=batch["chosen_input_ids"], attention_mask=batch["chosen_attention_mask"])
rejected_rewards = model(input_ids=batch["rejected_input_ids"], attention_mask=batch["rejected_attention_mask"])
test_loss, choosen_last_rewards, rejected_last_rewards = reward_loss(
test_loss, choosen_last_rewards, rejected_last_rewards = loss(
chosen=batch["chosen_input_ids"],
rejected=batch["rejected_input_ids"],
chosen_rewards=chosen_rewards,
@@ -89,15 +90,6 @@ def main():
model_args = OPT()
gen_args = GenerationArgs()

# Customization for default arguments
train_args.learning_rate = 1e-5
train_args.epochs = 1
train_args.lr_warmup_steps = 100
train_args.micro_batch_size = 16
train_args.mini_batch_size = 16
train_args.eval_interval = 50
model_args.disable_dropout = False

data_args_path = Path(train_args.data_dir) / "args.json"
data_args = TextDataArgs.from_json(str(data_args_path))

@@ -122,6 +114,10 @@
# Log hyperparameters
fabric.logger.log_hyperparams(all_args)

# Setup Reward Loss

reward_loss = load_reward_loss(train_args.loss_type)

# Setup Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name)

@@ -134,14 +130,9 @@
with EmptyInitOnDevice(fabric.device):
model = CriticModel.from_checkpoint(device=fabric.device, model_args=model_args, freeze=True)
model = model.to(fabric.device)
if model_args.apply_lora:
add_lora(model, model_args)
else:
fabric.print("Not applying LORA, Finetuning all parameters")
for param in model.parameters():
param.requires_grad = True
setup_finetuning(fabric=fabric, model=model, model_args=model_args)

trainable_parameter_summary(fabric, model, show_names=False)
trainable_parameter_summary(model, show_names=False, fabric=fabric)

# Setup Generation Config
try:
@@ -232,6 +223,7 @@ def main():
test_loss, test_acc = evaluate(
model=model,
test_dataloader=test_dataloader,
loss=reward_loss,
pad_token_id=tokenizer.pad_token_id,
eval_iters=train_args.eval_iters,
)
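For context, a minimal usage sketch, not part of the commit, of how the new loss selection is wired: load_reward_loss maps the loss_type string to one of the three reward-loss callables defined in loss.py, and the resulting callable is passed into evaluate exactly as in the hunk above; any other string raises a ValueError.

from sheeprl.algos.rlhf.loss import load_reward_loss

for loss_type in ("average", "last_token", "per_sample"):
    loss_fn = load_reward_loss(loss_type)
    print(loss_type, "->", loss_fn.__name__)
# average    -> reward_loss_average
# last_token -> reward_loss_last_token
# per_sample -> reward_loss_per_sample

# load_reward_loss("median")  # would raise ValueError: Invalid reward loss type: median

# In rlhf_rm.py the callable is resolved once from the training arguments and then
# handed to evaluate(), as shown in the diff above:
# reward_loss = load_reward_loss(train_args.loss_type)
# test_loss, test_acc = evaluate(model=model, test_dataloader=test_dataloader,
#                                loss=reward_loss, pad_token_id=tokenizer.pad_token_id,
#                                eval_iters=train_args.eval_iters)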
