Skip to content

Commit

Permalink
fix rewards mapping
Browse files (browse the repository at this point in the history)
  • Loading branch information
d.taranets committed Nov 27, 2024
1 parent 4ef4f87 commit b868f66
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions turbo_alignment/generators/rm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,10 +78,14 @@ def _generate_from_batch(

rewards = torch.cat(rewards, dim=0)

record_rewards = [
{answer_id: reward.item() for answer_id, reward in zip(record['answers'].keys(), rewards)}
for record in records
]
reward_index = 0
record_rewards = []
for record in records:
mapped_rewards = {}
for key in record['answers'].keys():
mapped_rewards[key] = rewards[reward_index].item()
reward_index += 1
record_rewards.append(mapped_rewards)

return [
RMSamplingInferenceOutput(
Expand Down

0 comments on commit b868f66

Please sign in to comment.