From b868f6664c8f714b3b136db780ad7d30db2a88b4 Mon Sep 17 00:00:00 2001 From: "d.taranets" Date: Wed, 27 Nov 2024 16:38:36 +0300 Subject: [PATCH] fix rewards mapping --- turbo_alignment/generators/rm.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py index ce9a9c8..55cdf90 100755 --- a/turbo_alignment/generators/rm.py +++ b/turbo_alignment/generators/rm.py @@ -78,10 +78,14 @@ def _generate_from_batch( rewards = torch.cat(rewards, dim=0) - record_rewards = [ - {answer_id: reward.item() for answer_id, reward in zip(record['answers'].keys(), rewards)} - for record in records - ] + reward_index = 0 + record_rewards = [] + for record in records: + mapped_rewards = {} + for key in record['answers'].keys(): + mapped_rewards[key] = rewards[reward_index].item() + reward_index += 1 + record_rewards.append(mapped_rewards) return [ RMSamplingInferenceOutput(