From 353be2512ac0edf5a49fdef55edeb956fc582885 Mon Sep 17 00:00:00 2001
From: rcmalli
Date: Tue, 27 Jun 2023 08:35:17 +0000
Subject: [PATCH] feat(RLHF): update reward modelling code with new arguments

---
 sheeprl/algos/rlhf/loss.py    | 33 ++++++++++++++++++++++++++++++++-
 sheeprl/algos/rlhf/rlhf_rm.py | 32 ++++++++++++--------------------
 2 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/sheeprl/algos/rlhf/loss.py b/sheeprl/algos/rlhf/loss.py
index 80ef9a19..a6213314 100644
--- a/sheeprl/algos/rlhf/loss.py
+++ b/sheeprl/algos/rlhf/loss.py
@@ -23,7 +23,27 @@ def reward_loss_last_token(
     return -F.logsigmoid(filtered_rewards).mean(), chosen_last_rewards, rejected_last_rewards
 
 
-def reward_loss(
+def reward_loss_average(
+    chosen: torch.Tensor,
+    rejected: torch.Tensor,
+    chosen_rewards: torch.Tensor,
+    rejected_rewards: torch.Tensor,
+    pad_token_id: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Log-sigmoid loss on the difference between the chosen and rejected rewards, averaged over all output tokens (padding tokens excluded)."""
+    pad_mask_chosen = chosen != pad_token_id  # (B, T)
+    pad_mask_rejected = rejected != pad_token_id  # (B, T)
+
+    chosen_rewards_average = chosen_rewards * pad_mask_chosen
+    chosen_rewards_average = chosen_rewards_average.sum(dim=1) / pad_mask_chosen.sum(dim=1)
+    rejected_rewards_average = rejected_rewards * pad_mask_rejected
+    rejected_rewards_average = rejected_rewards_average.sum(dim=1) / pad_mask_rejected.sum(dim=1)
+
+    filtered_rewards = chosen_rewards_average - rejected_rewards_average
+    return -F.logsigmoid(filtered_rewards).mean(), chosen_rewards_average, rejected_rewards_average
+
+
+def reward_loss_per_sample(
     chosen: torch.Tensor,
     rejected: torch.Tensor,
     chosen_rewards: torch.Tensor,
@@ -96,6 +116,17 @@ def reward_loss(
     return loss, chosen_last_rewards, rejected_last_rewards
 
 
+def load_reward_loss(reward_loss_type: str):
+    if reward_loss_type == "average":
+        return reward_loss_average
+    elif reward_loss_type == "last_token":
+        return reward_loss_last_token
+    elif reward_loss_type == "per_sample":
+        return reward_loss_per_sample
+    else:
+        raise ValueError(f"Invalid reward loss type: {reward_loss_type}")
+
+
 def finetune_loss(
     outputs: torch.Tensor, targets: torch.Tensor, ignore_index: int = -100, label_smoothing: float = 0.0
 ) -> torch.Tensor:
diff --git a/sheeprl/algos/rlhf/rlhf_rm.py b/sheeprl/algos/rlhf/rlhf_rm.py
index c80202f9..21f615d6 100644
--- a/sheeprl/algos/rlhf/rlhf_rm.py
+++ b/sheeprl/algos/rlhf/rlhf_rm.py
@@ -3,6 +3,7 @@
 import sys
 import time
 from dataclasses import asdict
+from typing import Callable
 
 import lightning as L
 import torch
@@ -18,13 +19,13 @@
 
 from sheeprl.algos.rlhf.args import OPT, GenerationArgs, ModelArgs, RMArgs, TextDataArgs
 from sheeprl.algos.rlhf.data import RMCollate
-from sheeprl.algos.rlhf.lora_utils import add_lora
-from sheeprl.algos.rlhf.loss import reward_loss
+from sheeprl.algos.rlhf.loss import load_reward_loss
 from sheeprl.algos.rlhf.models import CriticModel
 from sheeprl.algos.rlhf.scheduler import CosineSchedulerWithWarmup
 from sheeprl.algos.rlhf.utils import (
     prepare_optimizer_parameters,
     save_args_to_json,
+    setup_finetuning,
     trainable_parameter_summary,
 )
 
@@ -45,7 +46,7 @@ def accuracy(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
 
 
 @torch.inference_mode()
-def evaluate(model: CriticModel, test_dataloader: DataLoader, pad_token_id: int, eval_iters: int):
+def evaluate(model: CriticModel, test_dataloader: DataLoader, loss: Callable, pad_token_id: int, eval_iters: int):
     model.eval()
     eval_counter = 0
     average_acc = 0
@@ -53,7 +54,7 @@ def evaluate(model: CriticModel, test_dataloader: DataLoader, pad_token_id: int,
     for batch in test_dataloader:
         chosen_rewards = model(input_ids=batch["chosen_input_ids"], attention_mask=batch["chosen_attention_mask"])
         rejected_rewards = model(input_ids=batch["rejected_input_ids"], attention_mask=batch["rejected_attention_mask"])
-        test_loss, choosen_last_rewards, rejected_last_rewards = reward_loss(
+        test_loss, choosen_last_rewards, rejected_last_rewards = loss(
            chosen=batch["chosen_input_ids"],
            rejected=batch["rejected_input_ids"],
            chosen_rewards=chosen_rewards,
@@ -89,15 +90,6 @@ def main():
     model_args = OPT()
     gen_args = GenerationArgs()
 
-    # Customization for default arguments
-    train_args.learning_rate = 1e-5
-    train_args.epochs = 1
-    train_args.lr_warmup_steps = 100
-    train_args.micro_batch_size = 16
-    train_args.mini_batch_size = 16
-    train_args.eval_interval = 50
-    model_args.disable_dropout = False
-
     data_args_path = Path(train_args.data_dir) / "args.json"
     data_args = TextDataArgs.from_json(str(data_args_path))
 
@@ -122,6 +114,10 @@ def main():
     # Log hyperparameters
     fabric.logger.log_hyperparams(all_args)
 
+    # Setup Reward Loss
+
+    reward_loss = load_reward_loss(train_args.loss_type)
+
     # Setup Tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name)
 
@@ -134,14 +130,9 @@ def main():
     with EmptyInitOnDevice(fabric.device):
         model = CriticModel.from_checkpoint(device=fabric.device, model_args=model_args, freeze=True)
     model = model.to(fabric.device)
-    if model_args.apply_lora:
-        add_lora(model, model_args)
-    else:
-        fabric.print("Not applying LORA, Finetuning all parameters")
-        for param in model.parameters():
-            param.requires_grad = True
+    setup_finetuning(fabric=fabric, model=model, model_args=model_args)
 
-    trainable_parameter_summary(fabric, model, show_names=False)
+    trainable_parameter_summary(model, show_names=False, fabric=fabric)
 
     # Setup Generation Config
     try:
@@ -232,6 +223,7 @@ def main():
         test_loss, test_acc = evaluate(
             model=model,
             test_dataloader=test_dataloader,
+            loss=reward_loss,
             pad_token_id=tokenizer.pad_token_id,
             eval_iters=train_args.eval_iters,
         )
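
Reviewer note, not part of the patch itself: the standalone sketch below shows how the new masked-average reward loss and the load_reward_loss selector are intended to behave and be called. The dummy batch, the pad token id of 0, and the dict-based selector (used here in place of the patch's if/elif chain) are illustrative assumptions only, not code from this commit.

from typing import Callable, Tuple

import torch
import torch.nn.functional as F


def reward_loss_average(
    chosen: torch.Tensor,          # (B, T) chosen token ids
    rejected: torch.Tensor,        # (B, T) rejected token ids
    chosen_rewards: torch.Tensor,  # (B, T) per-token rewards for the chosen sequence
    rejected_rewards: torch.Tensor,
    pad_token_id: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Mask out padding tokens, then average the per-token rewards of each sequence.
    chosen_mask = chosen != pad_token_id
    rejected_mask = rejected != pad_token_id
    chosen_avg = (chosen_rewards * chosen_mask).sum(dim=1) / chosen_mask.sum(dim=1)
    rejected_avg = (rejected_rewards * rejected_mask).sum(dim=1) / rejected_mask.sum(dim=1)
    # Pairwise (Bradley-Terry style) objective: push chosen rewards above rejected ones.
    loss = -F.logsigmoid(chosen_avg - rejected_avg).mean()
    return loss, chosen_avg, rejected_avg


def load_reward_loss(reward_loss_type: str) -> Callable:
    # Same selection pattern as the patch; only the "average" variant is sketched here.
    losses = {"average": reward_loss_average}
    try:
        return losses[reward_loss_type]
    except KeyError:
        raise ValueError(f"Invalid reward loss type: {reward_loss_type}")


if __name__ == "__main__":
    B, T, pad = 2, 5, 0
    chosen = torch.tensor([[5, 6, 7, pad, pad], [3, 4, 5, 6, pad]])
    rejected = torch.tensor([[8, 9, pad, pad, pad], [2, 3, 4, pad, pad]])
    loss_fn = load_reward_loss("average")
    loss, chosen_avg, rejected_avg = loss_fn(
        chosen=chosen,
        rejected=rejected,
        chosen_rewards=torch.randn(B, T),
        rejected_rewards=torch.randn(B, T),
        pad_token_id=pad,
    )
    print(loss.item(), chosen_avg.shape, rejected_avg.shape)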