From c0c3e14225e41afabefe8c9722cc97b9085f9fab Mon Sep 17 00:00:00 2001 From: Almaz Dautov Date: Tue, 24 Dec 2024 16:34:54 +0300 Subject: [PATCH] rm replicas added --- turbo_alignment/cli/train.py | 4 +++- turbo_alignment/common/logging/clearml.py | 7 +++++-- turbo_alignment/common/tf/special_tokens_setter.py | 4 ++-- turbo_alignment/dataset/chat/chat.py | 4 ++-- turbo_alignment/generators/rm.py | 8 ++++++-- turbo_alignment/settings/pipelines/train/reinforce.py | 1 + turbo_alignment/trainers/online/ray/rayactor_group.py | 5 +++++ turbo_alignment/trainers/online/reinforce.py | 9 +++++++-- 8 files changed, 31 insertions(+), 11 deletions(-) diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py index af38656..0ce42ab 100755 --- a/turbo_alignment/cli/train.py +++ b/turbo_alignment/cli/train.py @@ -131,7 +131,9 @@ def reinforce_training( experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path) policy_models = RayGroup(num_nodes=experiment_settings.trainer_settings.num_nodes, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy) - reward_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=RewardModel) + + assert 1 <= experiment_settings.trainer_settings.reward_model_replicas <= 8 + reward_model = RayGroup(num_nodes=1, num_gpus_per_node=experiment_settings.trainer_settings.reward_model_replicas, ray_actor_type=RewardModel) # reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel) # TODO_RLOO if possible hide init inside RayGroup diff --git a/turbo_alignment/common/logging/clearml.py b/turbo_alignment/common/logging/clearml.py index c429439..3422f56 100644 --- a/turbo_alignment/common/logging/clearml.py +++ b/turbo_alignment/common/logging/clearml.py @@ -7,9 +7,12 @@ def create_clearml_task(parameters: ClearMLSettings, config: dict[str, Any] | None = None) -> Task: clearml_task = Task.init( - task_name=parameters.task_name, project_name=parameters.project_name, continue_last_task=True # FIXME? 
+ task_name=parameters.task_name, + project_name=parameters.project_name, + continue_last_task=True, + output_uri=False, ) clearml_task.connect_configuration(config, name='HyperParameters') - return clearml_task + return clearml_task \ No newline at end of file diff --git a/turbo_alignment/common/tf/special_tokens_setter.py b/turbo_alignment/common/tf/special_tokens_setter.py index 5fb4353..b1aaeb6 100755 --- a/turbo_alignment/common/tf/special_tokens_setter.py +++ b/turbo_alignment/common/tf/special_tokens_setter.py @@ -77,7 +77,7 @@ def setSEP(self, sep_token: str | None) -> None: return None def set_all(self) -> None: - # self.setBOS(bos_token=self._special_tokens_settings.bos_token) + self.setBOS(bos_token=self._special_tokens_settings.bos_token) self.setEOS(eos_token=self._special_tokens_settings.eos_token) self.setPAD(pad_token=self._special_tokens_settings.pad_token) self.setUNK(unk_token=self._special_tokens_settings.unk_token) @@ -94,7 +94,7 @@ def set_custom_tokens(self, tokens: list[str]) -> None: assert added_tokens == len(tokens) def setup_model_config(self, model: PreTrainedModel) -> None: - # model.config.bos_token_id = self._tokenizer.bos_token_id + model.config.bos_token_id = self._tokenizer.bos_token_id model.config.eos_token_id = self._tokenizer.eos_token_id if self._tokenizer.pad_token_id is not None: diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index 48309c0..ebdede9 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -247,8 +247,8 @@ def _truncate_and_merge( input_ids = input_ids[cut_index:] labels = labels[-len(input_ids) :] - # input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids)) - # labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels)) + input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids)) + labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels)) return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound) diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py index 7ff069d..9d749d8 100755 --- a/turbo_alignment/generators/rm.py +++ b/turbo_alignment/generators/rm.py @@ -9,16 +9,20 @@ import ray class RayRMSamplingGenerator(BaseGenerator[SamplingDatasetRecord, RMSamplingInferenceOutput]): - def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, **kwargs): + def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, model_replicas: int = 1, **kwargs): self._collator = DataCollatorWithPadding(tokenizer=tokenizer) self._micro_batch = micro_batch + self.model_replicas = model_replicas super().__init__(tokenizer=tokenizer, **kwargs) def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEncoding) -> torch.Tensor: with torch.no_grad(): # TODO assuming that only one reward model records = {k: v.cuda() for k, v in records.items()} - rewards = ray.get(self._model.async_forward(records))[0] + + rank = torch.distributed.get_rank() + + rewards = ray.get(self._model.reward_forward(records=records, index=rank % self.model_replicas)) return rewards # .squeeze() def generate_from_batch( diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py index c530731..7eb5170 100644 --- a/turbo_alignment/settings/pipelines/train/reinforce.py +++ b/turbo_alignment/settings/pipelines/train/reinforce.py @@ -18,6 +18,7 @@ class 
REINFORCETrainerSettings(TrainerSettings): num_nodes: int = 2 + reward_model_replicas: int = 1 max_new_tokens: int = 1024 stop_token: str = '' diff --git a/turbo_alignment/trainers/online/ray/rayactor_group.py b/turbo_alignment/trainers/online/ray/rayactor_group.py index 80e16bf..d9911d2 100644 --- a/turbo_alignment/trainers/online/ray/rayactor_group.py +++ b/turbo_alignment/trainers/online/ray/rayactor_group.py @@ -104,6 +104,11 @@ def reference_forward(self, records: dict[str, torch.Tensor], temperature, loss_ # TODO assuming one reference model return self._actor_handlers[0].reference_forward.remote(records, temperature, loss_mask) + def reward_forward(self, records: dict[str, torch.Tensor], index: int): + assert index < len(self._actor_handlers), f'Reward model replicas: {len(self._actor_handlers)}, provided index: {index}' + + return self._actor_handlers[index].forward.remote(records) + def async_forward(self, records: dict[str, torch.Tensor]): return [actor.forward.remote(records) for actor in self._actor_handlers] diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py index 5f62940..d413029 100644 --- a/turbo_alignment/trainers/online/reinforce.py +++ b/turbo_alignment/trainers/online/reinforce.py @@ -83,6 +83,7 @@ def mean_std_format(values): @dataclass class REINFORCETrainingArguments(TrainingArguments): num_nodes: int = 2 + reward_model_replicas: int = 1 max_new_tokens: int = 1024 stop_token: str = '' @@ -228,6 +229,9 @@ def __init__( print("Generations Params:\n" + "\n".join([f"{attr}: {getattr(self.generator_transformers_settings, attr, None)}" for attr, _ in self.generator_transformers_settings.__annotations__.items()])) start = time.time() + + # if num_samples_for_reward_stats == 0 then no normalization is done + self.norm_reward_mean, self.norm_reward_std = self.reward_stats( model=self.model, dataloader=self.get_train_dataloader() ) @@ -268,6 +272,7 @@ def _get_rm_generator(self, reward_model: torch.nn.Module | PreTrainedModel) -> generator = RayRMSamplingGenerator( model=reward_model, tokenizer=self.tokenizer, + model_replicas=self.args.reward_model_replicas, )#TODO this type of critic is created for Reward models with CausalLM head and utilize vllm engines case CriticType.DISTRIBUTED_VLLM: generator = ... @@ -527,8 +532,8 @@ def reward_stats( self, model: torch.nn.Module | PreTrainedModel, dataloader: torch.utils.data.DataLoader ) -> tuple[float, float]: if self.args.num_samples_for_reward_stats == 0: - norm_reward_mean = 0 - norm_reward_std = 1 + norm_reward_mean = 0.0 + norm_reward_std = 1.0 else: all_rewards = [] samples_processed = 0
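
The patch lets the reward model run as several Ray actors on a single node (1 to 8 GPUs, hence the assert in cli/train.py) and has every trainer rank pick a replica by round-robin indexing: RayRMSamplingGenerator computes `rank % model_replicas` and RayGroup.reward_forward dispatches to that actor after a bounds check. Below is a minimal, self-contained sketch of that routing policy; FakeRewardGroup and route_reward_request are illustrative stand-ins, not names from the repository.

# Minimal sketch (not the project's actual Ray setup): illustrates the
# `rank % reward_model_replicas` round-robin policy used in
# RayRMSamplingGenerator.generate_from_batch_records and the bounds check
# added to RayGroup.reward_forward.

from dataclasses import dataclass, field


@dataclass
class FakeRewardGroup:
    """Stand-in for RayGroup holding one handler per reward-model replica."""
    replicas: int
    calls: list[int] = field(default_factory=list)

    def reward_forward(self, records: dict, index: int) -> str:
        # Mirrors the assert added in rayactor_group.py.
        assert index < self.replicas, (
            f'Reward model replicas: {self.replicas}, provided index: {index}'
        )
        self.calls.append(index)
        return f'reward from replica {index}'


def route_reward_request(group: FakeRewardGroup, rank: int, records: dict) -> str:
    # Same policy as generate_from_batch_records: each trainer rank picks a
    # replica deterministically, spreading requests across the available GPUs.
    return group.reward_forward(records=records, index=rank % group.replicas)


if __name__ == '__main__':
    group = FakeRewardGroup(replicas=4)
    # Eight trainer ranks (one node with 8 GPUs) querying four replicas.
    for rank in range(8):
        print(rank, '->', route_reward_request(group, rank, records={}))
    # Each replica serves exactly two ranks: [0, 1, 2, 3, 0, 1, 2, 3]
    print('call sequence:', group.calls)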
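
The reinforce.py hunk also spells out the normalization fallback: when num_samples_for_reward_stats == 0 the trainer keeps norm_reward_mean = 0.0 and norm_reward_std = 1.0, so reward normalization degenerates to the identity. A tiny sketch of that behaviour, assuming the usual (reward - mean) / std form; normalize_reward is a hypothetical helper, not code from the repo.

# Sketch only: shows why mean=0.0, std=1.0 means "no normalization is done"
# when no samples are collected for reward statistics.

def normalize_reward(reward: float, mean: float = 0.0, std: float = 1.0) -> float:
    # Assumed standard form of reward normalization.
    return (reward - mean) / std


assert normalize_reward(3.5) == 3.5  # identity when stats were not estimated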