[RLOO] Reinforce++ (huggingface#2552)
* Reinforce++

* formatting

* fix link
kashif authored Jan 9, 2025
1 parent abfffc5 commit edabe0a
Showing 3 changed files with 79 additions and 1 deletion.
11 changes: 11 additions & 0 deletions docs/source/rloo_trainer.md
@@ -269,6 +269,17 @@ python -m openrlbenchmark.rlops_multi_metrics \
--scan-history
```

## Reinforce++

The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jian Hu proposes several optimization tricks to improve the performance and stability of RLHF training. They include:

- Clipping rewards: limiting reward values to a fixed range to mitigate the impact of extreme rewards on model updates and prevent gradient explosion
- Normalizing rewards: scaling rewards to zero mean and unit standard deviation, which stabilizes training
- Normalizing advantages: scaling advantages to zero mean and unit standard deviation, likewise stabilizing training
- Using a token-level KL penalty (the default) rather than a sequence-level KL penalty

These options are available via the corresponding arguments in the [`RLOOConfig`] class; the sketch below shows all four together.
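
A minimal configuration sketch enabling all four options (the `output_dir` value is a placeholder; the other values shown are the defaults except for the two normalization flags):

```python
from trl import RLOOConfig

config = RLOOConfig(
    output_dir="rloo-model",   # placeholder path
    normalize_reward=True,     # scale rewards to zero mean, unit std
    reward_clip_range=10.0,    # clamp normalized rewards to [-10, 10]
    normalize_advantage=True,  # scale advantages to zero mean, unit std
    token_level_kl=True,       # per-token KL penalty (the default)
)
```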


## RLOOTrainer

24 changes: 24 additions & 0 deletions trl/trainer/rloo_config.py
@@ -42,6 +42,14 @@ class RLOOConfig(OnPolicyConfig):
Clip range.
rloo_k (`int`, *optional*, defaults to `2`):
REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
normalize_reward (`bool`, *optional*, defaults to `False`):
Whether to normalize rewards.
reward_clip_range (`float`, *optional*, defaults to `10.0`):
Clip range for rewards.
normalize_advantage (`bool`, *optional*, defaults to `False`):
Whether to normalize advantages.
token_level_kl (`bool`, *optional*, defaults to `True`):
        Whether to apply the KL penalty at the token level (`True`) or at the sequence level (`False`).
"""

exp_name: str = field(
@@ -72,3 +80,19 @@ class RLOOConfig(OnPolicyConfig):
default=2,
metadata={"help": "REINFORCE Leave-One-Out (RLOO) number of online samples per prompt."},
)
normalize_reward: bool = field(
default=False,
metadata={"help": "Whether to normalize rewards"},
)
reward_clip_range: float = field(
default=10.0,
metadata={"help": "Clip range for rewards"},
)
normalize_advantage: bool = field(
default=False,
metadata={"help": "Whether to normalize advantages"},
)
token_level_kl: bool = field(
default=True,
metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"},
)
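
Since these are dataclass fields, they also surface as command-line flags when the config is built with `transformers.HfArgumentParser`; a sketch (the script name and flag values are illustrative):

```python
from transformers import HfArgumentParser
from trl import RLOOConfig

# Invoked as, e.g.:
#   python train_rloo.py --output_dir rloo-model --normalize_reward True \
#       --reward_clip_range 10.0 --normalize_advantage True --token_level_kl True
parser = HfArgumentParser(RLOOConfig)
(config,) = parser.parse_args_into_dataclasses()
print(config.normalize_reward, config.reward_clip_range)
```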
45 changes: 44 additions & 1 deletion trl/trainer/rloo_trainer.py
@@ -308,6 +308,8 @@ def repeat_generator():
ref_logprobs = []
scores = []
sequence_lengths = []

# Generate responses and compute logprobs
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
query_responses, logitss = batch_generation(
unwrapped_model,
@@ -317,6 +319,7 @@
generation_config,
)

# Process responses in batches
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
@@ -349,12 +352,15 @@
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)

# Store batch results
responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
ref_logprobs.append(ref_logprob)
sequence_lengths.append(sequence_length)
scores.append(score)

# Concatenate all batched results
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
logprobs = torch.cat(logprobs, 0)
@@ -380,15 +386,35 @@
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)

# 4. compute rewards
# Compute KL divergence
kl = logprobs - ref_logprobs

# Normalize rewards
if args.normalize_reward:
scores = (scores - scores.mean()) / (scores.std() + 1e-8)
scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)

# Compute total reward with KL penalty
if args.token_level_kl:
# Token-level KL penalty: apply KL penalty per token
token_kl_penalty = -args.kl_coef * kl
non_score_reward = token_kl_penalty.sum(1)
else:
# Sequence-level KL penalty: sum KL across tokens first
sequence_kl = kl.sum(1)
non_score_reward = -args.kl_coef * sequence_kl
rlhf_reward = scores + non_score_reward

# vectorized RLOO advantages implementation
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
advantages = rlhf_reward - baseline
advantages = advantages.flatten()

# Normalize advantages
if args.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

torch.cuda.empty_cache()

# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
@@ -403,32 +429,46 @@
with accelerator.accumulate(model):
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]

# Get batch data
mb_advantage = advantages[micro_batch_inds]
mb_responses = responses[micro_batch_inds]
mb_query_responses = query_responses[micro_batch_inds]
mb_logprobs = logprobs[micro_batch_inds]

# Forward pass
output = forward(model, mb_query_responses, processing_class.pad_token_id)
logits = output.logits[:, context_length - 1 : -1]
logits /= args.temperature + 1e-7

# Compute new logprobs
new_all_logprobs = F.log_softmax(logits, dim=-1)
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
new_logprobs = torch.masked_fill(
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
)

# Compute probability ratios
new_ratio = (new_logprobs - mb_logprobs).exp()
new_logprobs = new_logprobs.sum(1)
mb_logprobs = mb_logprobs.sum(1)
logprobs_diff = new_logprobs - mb_logprobs
ratio = torch.exp(logprobs_diff)

# PPO clipped loss
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
pg_loss_max = torch.max(pg_losses, pg_losses2)
pg_loss = pg_loss_max.mean()

# Final loss
loss = pg_loss

# Optimization step
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()

with torch.no_grad():
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
Expand All @@ -443,6 +483,7 @@ def repeat_generator():
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1

# del everything and empty cache
# fmt: off
del (
@@ -453,6 +494,8 @@ def repeat_generator():
)
# fmt: on
torch.cuda.empty_cache()

# Compute metrics
with torch.no_grad():
mean_kl = kl.sum(1).mean()
mean_entropy = (-logprobs).sum(1).mean()
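
For intuition, here is a standalone sketch of the vectorized leave-one-out advantage computation from the trainer hunk above, with invented reward values and `rloo_k = 2`:

```python
import torch

rloo_k = 2
# Rewards for 3 prompts x 2 completions, laid out so that row r of the
# reshaped tensor holds the r-th completion of every prompt.
rlhf_reward = torch.tensor([1.0, 2.0, 3.0, 5.0, 4.0, 3.0])

rlhf_reward = rlhf_reward.reshape(rloo_k, -1)                 # (k, num_prompts)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)  # mean of siblings
advantages = (rlhf_reward - baseline).flatten()
print(advantages)  # tensor([-4., -2., 0., 4., 2., 0.])
```

Each completion is scored against the mean reward of the other `k - 1` completions for the same prompt, so the first prompt's two samples (rewards 1.0 and 5.0) receive advantages -4.0 and +4.0.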
