From f5a7d1d30fbe12f3d2796726591583f65ef671ce Mon Sep 17 00:00:00 2001
From: "n.surnachev"
Date: Sat, 19 Oct 2024 22:51:54 +0300
Subject: [PATCH] in progress

---
 turbo_alignment/common/distributed.py        | 1 -
 turbo_alignment/trainers/online/reinforce.py | 5 -----
 2 files changed, 6 deletions(-)

diff --git a/turbo_alignment/common/distributed.py b/turbo_alignment/common/distributed.py
index 07aa695..a81c3cf 100644
--- a/turbo_alignment/common/distributed.py
+++ b/turbo_alignment/common/distributed.py
@@ -23,7 +23,6 @@ def get_global_mean(values: torch.Tensor) -> float:
 
     # Calculate the mean reward for the current process
     local_sum = values.sum().item()
-    print("WORLD SIZE 😼: ", world_size)
 
     if world_size == 1:
         return values.mean().item()
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index 5c014a6..85c8e20 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -343,11 +343,6 @@ def get_logprobs(
 
     def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.args.non_eos_penalty:
-            assert torch.all(query_response[:, -1] != self.tokenizer.pad_token_id), (
-                query_response[:, -1],
-                self.tokenizer.pad_token_id,
-            )
-
             invalid_mask = query_response[:, -1] != self.stop_generation_token_id[0]
             rewards[invalid_mask] = self.args.penalty_reward_value