fix token_level_kl
kashif committed Jan 16, 2025
1 parent 57d9a97 commit 475a157
Showing 2 changed files with 8 additions and 5 deletions.
trl/trainer/rloo_config.py: 2 changes (1 addition, 1 deletion)
```diff
@@ -93,6 +93,6 @@ class RLOOConfig(OnPolicyConfig):
         metadata={"help": "Whether to normalize advantages"},
     )
     token_level_kl: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"},
     )
```
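With this change, the token-level KL penalty becomes opt-in rather than the default. A minimal usage sketch, assuming `RLOOConfig` is importable from the top-level `trl` package as in recent releases (the `output_dir` and `kl_coef` values are illustrative, not taken from this commit):

```python
from trl import RLOOConfig

config = RLOOConfig(
    output_dir="rloo-model",  # illustrative output path
    token_level_kl=True,      # opt back into the per-token KL penalty; the new default is False
    kl_coef=0.05,             # KL coefficient, read as args.kl_coef in the trainer
)
```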
trl/trainer/rloo_trainer.py: 11 changes (7 additions, 4 deletions)
```diff
@@ -397,13 +397,16 @@ def repeat_generator():
                 # Compute total reward with KL penalty
                 if args.token_level_kl:
                     # Token-level KL penalty: apply KL penalty per token
-                    token_kl_penalty = -args.kl_coef * kl
-                    non_score_reward = token_kl_penalty.sum(1)
+                    kl_reward = -args.kl_coef * kl
+                    # Apply reward at the last non-padded token position for each sequence
+                    eos_indices = sequence_lengths.unsqueeze(1) - 1
+                    last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=scores.unsqueeze(1))
+                    non_score_reward = kl_reward.sum(1) + last_reward.sum(1)
                 else:
                     # Sequence-level KL penalty: sum KL across tokens first
                     sequence_kl = kl.sum(1)
-                    non_score_reward = -args.kl_coef * sequence_kl
-                rlhf_reward = scores + non_score_reward
+                    non_score_reward = -args.kl_coef * sequence_kl + scores
+                rlhf_reward = non_score_reward

                 # vectorized RLOO advantages implementation
                 rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
```
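For readers comparing the two branches above, here is a small self-contained sketch (illustrative only, not the trainer code; the tensor shapes and dummy values are assumptions) of how each mode reduces per-token KL estimates and a scalar score to one reward per sequence:

```python
import torch

# Toy inputs mirroring the names used in the diff (shapes and values are made up)
kl_coef = 0.05
kl = torch.rand(4, 10)                          # per-token KL estimate, shape (batch, seq_len)
scores = torch.rand(4)                          # scalar reward-model score per sequence, shape (batch,)
sequence_lengths = torch.tensor([10, 8, 9, 7])  # last non-padded position + 1 for each sequence

# Token-level: penalize every token, then place the score at the last non-padded token
kl_reward = -kl_coef * kl
eos_indices = sequence_lengths.unsqueeze(1) - 1  # (batch, 1)
last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=scores.unsqueeze(1))
rlhf_reward_token = kl_reward.sum(1) + last_reward.sum(1)

# Sequence-level: sum the KL over tokens first, then add the score once
rlhf_reward_seq = -kl_coef * kl.sum(1) + scores

# In this unmasked toy case the two reductions coincide numerically, since the
# token-level path also sums over the sequence before advantages are computed.
print(rlhf_reward_token)
print(rlhf_reward_seq)
```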
