Commit
address #41, be faithful to the paper
lucidrains committed Mar 23, 2023
1 parent d4faf48 commit 1487fd4
Showing 2 changed files with 15 additions and 10 deletions.
palm_rlhf_pytorch/ppo.py: 14 additions & 9 deletions
@@ -17,7 +17,7 @@
 from torch.utils.data import Dataset, DataLoader
 from torch.nn.utils.rnn import pad_sequence
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
 from palm_rlhf_pytorch.palm import PaLM
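The newly imported `reduce` is not used in the hunks shown here; presumably it pools per-token tensors elsewhere in `ppo.py` (a later hunk mentions handling non-pooled values). A minimal, hypothetical illustration of the einops call, with made-up tensor names:

```python
import torch
from einops import reduce

# hypothetical example: mean-pool per-token value estimates over the sequence dim
values_per_token = torch.randn(2, 16)                          # (batch, seq)
values_pooled = reduce(values_per_token, 'b n -> b', 'mean')
print(values_pooled.shape)                                     # torch.Size([2])
```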
@@ -263,16 +263,17 @@ def masked_entropy(prob, dim = -1, mask = None):
     entropies = (prob * log(prob)).sum(dim = -1)
     return masked_mean(entropies, mask = mask).mean()
 
-def masked_kl_div(prob1, prob2, mask = None):
+def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
     """
     need to account for variable sequence lengths, therefore not using the built-in functional version
     """
     kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
+    loss = masked_mean(kl_divs, mask)
 
-    if not exists(mask):
-        return kl_divs.mean()
+    if reduce_batch:
+        return loss.mean()
 
-    return masked_mean(kl_divs, mask).mean()
+    return loss
 
 def clipped_value_loss(values, rewards, old_values, clip):
     value_clipped = old_values + (values - old_values).clamp(-clip, clip)
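For context, here is a self-contained sketch of how the revised `masked_kl_div` behaves. The `log` and `masked_mean` stand-ins below are assumptions (the repository defines its own helpers); the point is that with the default `reduce_batch = False` the function now returns one KL value per sequence, averaged over unmasked (action) positions only, instead of collapsing everything to a single scalar.

```python
import torch

def log(t, eps = 1e-20):
    # assumed stand-in for the repo's log helper: numerically safe log
    return torch.log(t.clamp(min = eps))

def masked_mean(t, mask = None, dim = 1):
    # assumed stand-in: mean over the sequence dim, counting only unmasked positions
    if mask is None:
        return t.mean(dim = dim)
    t = t.masked_fill(~mask, 0.)
    return t.sum(dim = dim) / mask.sum(dim = dim).clamp(min = 1)

def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
    # per-token KL(prob1 || prob2) summed over the vocab dim, then masked-averaged per sequence
    kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
    loss = masked_mean(kl_divs, mask)   # shape: (batch,)

    if reduce_batch:
        return loss.mean()              # old behaviour: one scalar for the whole batch

    return loss                         # new default: one KL value per sequence

# toy usage: batch of 2 sequences, 4 tokens each, vocab of 8
old_probs = torch.softmax(torch.randn(2, 4, 8), dim = -1)
new_probs = torch.softmax(torch.randn(2, 4, 8), dim = -1)
action_mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]]).bool()

print(masked_kl_div(old_probs, new_probs, mask = action_mask).shape)  # torch.Size([2])
```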
@@ -502,10 +503,14 @@ def learn(
 
         # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
 
-        kl_div_loss = 0.
+        kl_penalty = 0.
 
         if self.kl_div_loss_weight > 0:
-            kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
+            kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
+
+        # subtract the kl penalty from the rewards
+
+        rewards = rewards - kl_penalty
 
         # handle non-pooled values
 
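This is the substantive change behind the commit title: rather than adding a weighted KL term to the policy loss, the KL between the old and new action distributions is scaled by `self.kl_div_loss_weight` and subtracted from the rewards, so divergence from the old policy is discouraged through the advantages instead of a separate loss term. A small numeric sketch of that reward shaping, with illustrative values rather than the repo's actual tensors:

```python
import torch

# illustrative: one scalar reward per sequence in the batch
rewards = torch.tensor([1.0, 0.5, 2.0])

# per-sequence KL as returned by masked_kl_div(old_action_probs, action_probs, mask = ...)
kl_per_seq = torch.tensor([0.10, 0.40, 0.05])

kl_div_loss_weight = 0.1                 # plays the role of self.kl_div_loss_weight

kl_penalty = kl_per_seq * kl_div_loss_weight
rewards = rewards - kl_penalty           # sequences that drift further get less reward

print(rewards)                           # tensor([0.9900, 0.4600, 1.9950])
```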
@@ -536,7 +541,7 @@ def learn(
 
         # combine losses
 
-        loss = policy_loss.mean() + kl_div_loss
+        loss = policy_loss.mean()
 
         # update actor
 
@@ -552,7 +557,7 @@ def learn(
 
         # calculate value loss and update value network separate from policy network
 
-        value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
+        value_loss = clipped_value_loss(values, rewards.detach(), old_values, self.value_clip)
         value_loss = value_loss.mean()
 
         self.print(f'critic_loss: {value_loss.item():.3f}')
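Because the shaped rewards now depend on `action_probs` from the actor's current forward pass (via `kl_penalty`), they are detached before being used as the critic target, so the value loss only updates the value network. The sketch below illustrates that stop-gradient; the body of `clipped_value_loss` beyond the two lines visible in the earlier hunk is an assumption based on the usual PPO-style clipped value objective, not read off this commit.

```python
import torch

def clipped_value_loss(values, rewards, old_values, clip):
    # clip the new value prediction to stay within `clip` of the old prediction
    value_clipped = old_values + (values - old_values).clamp(-clip, clip)
    # pessimistic (elementwise max) of clipped and unclipped squared errors (assumed form)
    return torch.max((value_clipped - rewards) ** 2, (values - rewards) ** 2)

values = torch.tensor([0.8, 1.2, 0.1], requires_grad = True)
old_values = torch.tensor([0.7, 1.0, 0.3])
rewards = torch.tensor([0.99, 0.46, 1.995], requires_grad = True)  # shaped rewards from the previous step

# detaching the rewards keeps critic gradients from flowing back into the actor
# through the kl penalty that was folded into them
value_loss = clipped_value_loss(values, rewards.detach(), old_values, clip = 0.4).mean()
value_loss.backward()

print(values.grad is not None)   # True: the value head is updated
print(rewards.grad)              # None: no gradient reaches the shaped rewards
```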
setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-rlhf-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.2.0',
   license='MIT',
   description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
   author = 'Phil Wang',
