From a55add25247e25ccde0a469191104e6bff992150 Mon Sep 17 00:00:00 2001
From: Almaz Dautov
Date: Fri, 1 Nov 2024 18:15:27 +0300
Subject: [PATCH] in progress

---
 turbo_alignment/settings/pipelines/train/reinforce.py | 11 ++++++++++-
 turbo_alignment/trainers/online/reinforce.py           |  1 +
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py
index be1c71a..8ae5a1b 100644
--- a/turbo_alignment/settings/pipelines/train/reinforce.py
+++ b/turbo_alignment/settings/pipelines/train/reinforce.py
@@ -6,7 +6,11 @@
 )
 from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings
 from turbo_alignment.settings.tf.trainer import TrainerSettings
-
+from turbo_alignment.settings.online import (
+    CriticType,
+    LLMActorType,
+    RewardProcessorType,
+)
 
 class REINFORCETrainerSettings(TrainerSettings):
     max_tokens_count: int = 1024
@@ -25,6 +29,11 @@ class REINFORCETrainerSettings(TrainerSettings):
     temperature: float | None = None
     whiten_rewards: bool = False
 
+    actor_type: LLMActorType = LLMActorType.LOCAL_TRANSFORMERS
+    critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS
+
+    reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO
+
 
 class REINFORCETrainExperimentSettings(BaseTrainExperimentSettings):
     train_dataset_settings: ChatMultiDatasetSettings
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index 85c8e20..6e760db 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -226,6 +226,7 @@ def get_batch_loss_metrics(
 
         kl_term = logprobs.detach() - ref_logprobs
         regularized_rewards = rewards - self.kl_coef * kl_term
+        print(f"{regularized_rewards.shape=}", flush=True)
 
         baselined_reward, baseline_metrics = self.reward_processor.baseline_rewards(rewards=regularized_rewards)
         loss = -baselined_reward * logprobs
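
Note (not part of the patch): a minimal, self-contained sketch of how the new enum-typed settings fields introduced above could behave. The enum names, member names, and defaults mirror the diff; the stand-in classes, string values, and the use of stdlib dataclasses instead of the project's actual TrainerSettings base are assumptions for illustration only.

    # Illustrative stand-ins for turbo_alignment.settings.online enums and for
    # REINFORCETrainerSettings; the real classes live in the patched modules.
    from dataclasses import dataclass
    from enum import Enum


    class LLMActorType(str, Enum):
        LOCAL_TRANSFORMERS = "local_transformers"  # string value is an assumption


    class CriticType(str, Enum):
        LOCAL_TRANSFORMERS = "local_transformers"  # string value is an assumption


    class RewardProcessorType(str, Enum):
        RLOO = "rloo"  # string value is an assumption


    @dataclass
    class REINFORCETrainerSettingsSketch:
        # Defaults mirror the patch: local-transformers actor/critic, RLOO baseline.
        actor_type: LLMActorType = LLMActorType.LOCAL_TRANSFORMERS
        critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS
        reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO


    if __name__ == "__main__":
        settings = REINFORCETrainerSettingsSketch()
        print(settings.actor_type, settings.reward_processor_type)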