Commit a55add2 ("in progress")
Almaz Dautov committed Nov 1, 2024
Parent: ca20098

Showing 2 changed files with 11 additions and 1 deletion.
11 changes: 10 additions & 1 deletion turbo_alignment/settings/pipelines/train/reinforce.py
@@ -6,7 +6,11 @@
 )
 from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings
 from turbo_alignment.settings.tf.trainer import TrainerSettings
-
+from turbo_alignment.settings.online import (
+    CriticType,
+    LLMActorType,
+    RewardProcessorType,
+)
 
 class REINFORCETrainerSettings(TrainerSettings):
     max_tokens_count: int = 1024
@@ -25,6 +29,11 @@ class REINFORCETrainerSettings(TrainerSettings):
     temperature: float | None = None
     whiten_rewards: bool = False
 
+    actor_type: LLMActorType = LLMActorType.LOCAL_TRANSFORMERS
+    critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS
+
+    reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO
+
 
 class REINFORCETrainExperimentSettings(BaseTrainExperimentSettings):
     train_dataset_settings: ChatMultiDatasetSettings
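The three new fields let a REINFORCE run choose where generation, reward scoring, and reward baselining happen. A minimal sketch of selecting them in code follows; the bare instantiation is hypothetical (TrainerSettings may require further fields), and only the field names and enum defaults come from the diff above:

    from turbo_alignment.settings.online import (
        CriticType,
        LLMActorType,
        RewardProcessorType,
    )

    settings = REINFORCETrainerSettings(
        actor_type=LLMActorType.LOCAL_TRANSFORMERS,      # generate completions with a local transformers model
        critic_type=CriticType.LOCAL_TRANSFORMERS,       # score completions with a local reward model
        reward_processor_type=RewardProcessorType.RLOO,  # leave-one-out reward baseline
    )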
1 change: 1 addition & 0 deletions turbo_alignment/trainers/online/reinforce.py
@@ -226,6 +226,7 @@ def get_batch_loss_metrics(
         kl_term = logprobs.detach() - ref_logprobs
         regularized_rewards = rewards - self.kl_coef * kl_term
 
+        print(f"{regularized_rewards.shape=}", flush=True)
         baselined_reward, baseline_metrics = self.reward_processor.baseline_rewards(rewards=regularized_rewards)
 
         loss = -baselined_reward * logprobs
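For context, this block of get_batch_loss_metrics implements KL-regularized REINFORCE: rewards are penalized by a per-sample KL estimate against the reference policy, baselined (RLOO by default, per the settings change above), and multiplied by the policy log-probabilities. The added print is a shape-debugging aid, consistent with the "in progress" commit message. Below is a self-contained sketch of the computation, assuming plain torch tensors; the group size k, the tensor layout, and the rloo_baseline helper are assumptions, while kl_term, regularized_rewards, and the loss line mirror the diff:

    import torch

    def rloo_baseline(rewards: torch.Tensor, k: int) -> torch.Tensor:
        # Assumed layout: rewards is (batch * k,), with k sampled completions per prompt.
        grouped = rewards.view(-1, k)
        # Each sample's baseline is the mean reward of the other k - 1 samples.
        baseline = (grouped.sum(dim=1, keepdim=True) - grouped) / (k - 1)
        return (grouped - baseline).view(-1)

    def reinforce_loss(logprobs, ref_logprobs, rewards, kl_coef, k):
        kl_term = logprobs.detach() - ref_logprobs           # per-sample KL estimate
        regularized_rewards = rewards - kl_coef * kl_term    # penalize drift from the reference policy
        baselined = rloo_baseline(regularized_rewards, k)    # variance reduction
        return (-baselined * logprobs).mean()                # REINFORCE policy-gradient estimator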
