From 0837aee8ca1726f958f0770a9396b9cac4a260e4 Mon Sep 17 00:00:00 2001
From: Almaz Dautov
Date: Tue, 3 Dec 2024 18:22:56 +0300
Subject: [PATCH] reference model update

---
 turbo_alignment/cli/train.py                  | 11 ++--
 .../common/tf/loaders/model/model.py          |  5 ++
 turbo_alignment/pipelines/train/reinforce.py  | 30 +++--------
 .../trainers/online/ray/rayactor_group.py     |  4 +-
 .../trainers/online/reference_actor.py        |  8 +--
 turbo_alignment/trainers/online/reinforce.py  | 50 +++++++++----------
 .../trainers/online/reward_actor.py           |  7 ++-
 7 files changed, 50 insertions(+), 65 deletions(-)

diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py
index d50792e..314a85c 100755
--- a/turbo_alignment/cli/train.py
+++ b/turbo_alignment/cli/train.py
@@ -124,20 +124,20 @@ def reinforce_training(
     from turbo_alignment.trainers.online.ray.rayactor_group import RayGroup
     from turbo_alignment.trainers.online.ray.vllm_engine import create_vllm_engines
     from turbo_alignment.trainers.online.reward_actor import RewardModel
-    from turbo_alignment.trainers.online.reference_actor import ReferenceModel
+    # from turbo_alignment.trainers.online.reference_actor import ReferenceModel
 
     ray.init(address="auto")
 
     experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path)
 
-    policy_models = RayGroup(num_nodes=2, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy)#64.19 GiB is allocated by PyTorch, and 3.40 GiB
+    policy_models = RayGroup(num_nodes=2, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy)
     reward_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=RewardModel)
-    reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)
+    # reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)
 
     # TODO_RLOO if possible hide init inside RayGroup
     ray.get(policy_models.async_init_model_from_pretrained())
     ray.get(reward_model.async_init_model_from_pretrained(rm_model=experiment_settings.reward_model_settings.model_path))
-    ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))
+    # ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))
 
     '''
     TODO_RLOO:
@@ -159,5 +159,6 @@ def reinforce_training(
     ray.get(policy_models.async_fit_actor_model(
         experiment_settings=experiment_settings,
         vllm_engines=vllm_engines,
-        reference_model=reference_model, reward_model=reward_model
+        # reference_model=reference_model,
+        reward_model=reward_model
     ))
diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py
index 1861b0c..57d89b3 100755
--- a/turbo_alignment/common/tf/loaders/model/model.py
+++ b/turbo_alignment/common/tf/loaders/model/model.py
@@ -40,6 +40,11 @@ def _load_pretrained_adapters(
         is_trainable=model_settings.is_trainable,
     )
 
+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
+
 
 def unfreeze_params(layer):
     for param in layer.parameters():
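
A minimal sketch of how the new disable_dropout_in_model helper behaves, using a toy torch.nn.Sequential stack instead of the repo's policy model (the toy model and shapes below are illustrative assumptions, not code from this patch):

import torch

def disable_dropout_in_model(model: torch.nn.Module) -> None:
    # Zero the dropout probability everywhere so forward passes are deterministic
    # with respect to dropout, even in training mode.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.1), torch.nn.Linear(8, 2))
disable_dropout_in_model(toy)
assert all(m.p == 0 for m in toy.modules() if isinstance(m, torch.nn.Dropout))
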
diff --git a/turbo_alignment/pipelines/train/reinforce.py b/turbo_alignment/pipelines/train/reinforce.py
index 60167b3..7e6c192 100644
--- a/turbo_alignment/pipelines/train/reinforce.py
+++ b/turbo_alignment/pipelines/train/reinforce.py
@@ -91,21 +91,6 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):
     def init_model_from_pretrained(self):
         self._setup_distributed()
-
-        # self.ds_config = get_train_ds_config(offload=False)
-        # self.ds_config["train_micro_batch_size_per_gpu"] = 1
-
-        # self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda')
-        # self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
-        # print(f"PolicyModel initialized on Node {self.node_id}, Local Rank {self.local_rank}")
-        # print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
-
-    # def tokenize(self, text: str):
-    #     return self.tokenizer(text, return_tensors='pt')
-
-    # def generate(self, text: str):
-    #     tokenized_input = self.tokenize(text).to('cuda')
-    #     return self.model(**tokenized_input)
 
     @staticmethod
     def _get_data_collator(
@@ -157,7 +142,7 @@ def _get_trainer(
         vllm_engines,
         training_args: REINFORCETrainingArguments,
         experiment_settings: REINFORCETrainExperimentSettings,
-        ref_model,
+        # ref_model,
         reward_model,
         model: PreTrainedModel,
         tokenizer: PreTrainedTokenizerBase,
@@ -169,9 +154,9 @@
 
         # TODO: TODO_RLOO load reference and reward model here
 
-        # ref_model = load_model(experiment_settings.model_settings, tokenizer)
-        # for _, param in ref_model.named_parameters():
-        #     param.requires_grad = False
+        ref_model = load_model(experiment_settings.model_settings, tokenizer)
+        for _, param in ref_model.named_parameters():
+            param.requires_grad = False
 
         # ref_model.eval()
@@ -222,7 +207,7 @@ def _get_datasets(self, experiment_settings: REINFORCETrainExperimentSettings) -
     get rid off vllm_engines, reference_model, reward_model if possible
     only get_trainer affected
     '''
-    def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_model, reward_model) -> None:
+    def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reward_model) -> None: #reference_model
         training_args = self._get_training_args(experiment_settings)
 
         print('HERE!!!!!!!!!!!!!!!!!!!!!!!!')
@@ -276,7 +261,7 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
             vllm_engines,
             training_args,
             experiment_settings,
-            reference_model,
+            # reference_model,
            reward_model,
             self.model,
             self.tokenizer,
@@ -285,8 +270,6 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
             data_collator,
         )
         print(f"Elapsed get_trainer time: {time.time() - start} seconds")
-
-        start = time.time()
 
         if self.trainer.accelerator.is_main_process:
             self._dataset_and_collator_sanity_check(train_dataset, data_collator)
@@ -302,7 +285,6 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
         self._save_experiment_metadata(
             experiment_metadata, Path(self.trainer.args.output_dir) / 'experiment_metadata.json'
         )
-        print(f"Elapsed before trainer.train() time: {time.time() - start} seconds")
 
         self.trainer.train()
         self.trainer.save_model()
diff --git a/turbo_alignment/trainers/online/ray/rayactor_group.py b/turbo_alignment/trainers/online/ray/rayactor_group.py
index 6ef4c2a..80e16bf 100644
--- a/turbo_alignment/trainers/online/ray/rayactor_group.py
+++ b/turbo_alignment/trainers/online/ray/rayactor_group.py
@@ -138,7 +138,7 @@ def async_fit_actor_model(
         self,
         experiment_settings: REINFORCETrainExperimentSettings,
         vllm_engines: List,
-        reference_model,
+        # reference_model,
         reward_model,
     ):
         refs = []
@@ -148,7 +148,7 @@
             actor.run.remote(
                 experiment_settings=experiment_settings,
                 vllm_engines=vllm_engines,
-                reference_model=reference_model,
+                # reference_model=reference_model,
                 reward_model=reward_model,
             )
diff --git a/turbo_alignment/trainers/online/reference_actor.py b/turbo_alignment/trainers/online/reference_actor.py
index bcbf633..1223ca1 100644
--- a/turbo_alignment/trainers/online/reference_actor.py
+++ b/turbo_alignment/trainers/online/reference_actor.py
@@ -14,7 +14,7 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):
     def init_model_from_pretrained(self, pretrain):
         self._setup_distributed()
-        self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16) #attn_implementation='flash_attention_2'
+        self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16, attn_implementation='flash_attention_2') #attn_implementation='flash_attention_2'
         self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
         print(f"Reference model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
         print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
@@ -42,12 +42,14 @@ def reference_forward(self, x, temperature, loss_mask):
         x = {k: v.cuda() for k, v in x.items()}
 
         print(f"{x.keys()}")
-        logits = self.model(**x).logits[:, :-1] # 35GB
+        logits = self.model(**x).logits[:, :-1]
+
         logits /= temperature
-        #logits = F.log_softmax(logits, dim=-1) # 35GB
 
         '''
         Memory Efficient implementation of log_softmax using in_place operation
+        equivalent to:
+        logits = F.log_softmax(logits, dim=-1)
         '''
         torch.exp(logits, out=logits)
         summed = torch.sum(logits, dim=-1, keepdim=True)
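
The in-place chain in reference_forward is meant to reproduce F.log_softmax without materialising a second logits-sized tensor. A minimal sketch of that equivalence on dummy logits (it skips the max-subtraction that F.log_softmax performs, so it assumes reasonably scaled inputs; shapes are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)                     # (batch, seq, vocab) dummy logits
reference = F.log_softmax(logits, dim=-1)

x = logits.clone()
torch.exp(x, out=x)                                # e^x, in place
summed = torch.sum(x, dim=-1, keepdim=True)        # normaliser per position
x /= summed                                        # softmax, in place
torch.log(x, out=x)                                # log-softmax, in place

torch.testing.assert_close(x, reference, rtol=1e-4, atol=1e-5)
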
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index da31ba2..eef8f3c 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 import torch
+import torch.distributed
 import torch.nn.functional as F
 import torch.utils
 import torch.utils.data
@@ -31,7 +32,7 @@
     ActorType,
     RewardProcessorType,
 )
-
+from turbo_alignment.common.tf.loaders.model.model import disable_dropout_in_model
 from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings
 from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
 from turbo_alignment.trainers.online.reward_processor import RewardProcessor
@@ -98,12 +99,12 @@ def __init__(
             **kwargs,
         )
         logging.info(f'super().__init__ elapsed time:{time.time() - start}')
+        print(f'{self.accelerator.num_processes=}', flush=True)
 
         self.vllm_engines = vllm_engines
         start = time.time()
         if self.vllm_engines is not None and torch.distributed.get_rank() == 0:
             master_address = ray._private.services.get_node_ip_address()
-            print(f'TRAINER DEBUG: {master_address=}', flush=True)
             with socket.socket() as sock:
                 sock.bind(("", 0))
                 master_port = sock.getsockname()[1]
@@ -150,16 +151,15 @@ def __init__(
         logging.info(f'distributed vllm engine __init__ elapsed time:{time.time() - start}')
 
         self.ref_model = ref_model
-        # TODO: delete later
-        # ray.get(self.ref_model.async_eval())
 
         # TODO: TODO_RLOO watch later
-        # if self.ref_model is not None:
-        #     self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
+        if self.ref_model is not None:
+            self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
+
+        disable_dropout_in_model(self.model)
+        disable_dropout_in_model(self.ref_model)
 
         self.reward_model = reward_model
-        # TODO: delete later
-        # ray.get(self.ref_model.async_eval())
 
         self._stored_metrics: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
         self.kl_coef = args.kl_coef
@@ -202,7 +202,6 @@ def __init__(
         )
         logging.info(f'statictis in __init__ elapsed time:{time.time() - start}')
         self.print_readable_stats()
-
 
     def _broadcast_to_vllm(self, model: DeepSpeedEngine):
         # avoid OOM
@@ -297,18 +296,8 @@ def get_answers_and_rewards(
         )
         logging.info(f'generations elapsed time:{time.time() - start}')
 
-        for g in generations:
-            for ans in g.answers:
-                print(f'{g.input_token_ids.device=}, {ans.answer_token_ids.device=}', flush=True)
-                print(f'{g.input_attention_mask.device=}, {ans.answer_attention_mask.device=}', flush=True)
-                break
-            break
         response_ids = [torch.cat([g.input_token_ids, ans.answer_token_ids], dim=1) for g in generations for ans in g.answers]
         response_attention_mask = [torch.cat([g.input_attention_mask, ans.answer_attention_mask], dim=1) for g in generations for ans in g.answers]
-
-        import random
-        ind = random.randint(0, len(response_ids) - 1)
-        assert response_ids[ind].shape == response_attention_mask[ind].shape
 
         if torch.distributed.get_rank() == 0:
             print(f'Prompt with completion at index [0] shape: {response_ids[0].shape}', flush=True)
@@ -375,16 +364,18 @@ def get_batch_loss_metrics(
             input_ids=query_response,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            loss_mask=response_tokens_mask.to(torch.bfloat16),
+            loss_mask=response_tokens_mask,
         )
+        logging.info(f'reference elapsed time:{time.time() - start}')
 
         rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)
-        start = time.time()
+
         del inputs
         gc.collect()
         torch.cuda.empty_cache()
 
+        start = time.time()
         logprobs = self.get_logprobs(
             model=model,
             input_ids=query_response,
@@ -393,11 +384,16 @@
             loss_mask=response_tokens_mask,
         )
         logging.info(f'policy logrobs elapsed time:{time.time() - start}')
+
+        # print(f'rank {torch.distributed.get_rank()}:{ref_logprobs.shape=}, {logprobs.shape=}', flush=True)
+        # print(f'rank {torch.distributed.get_rank()}:{ref_logprobs=}, {logprobs=}', flush=True)
+        # print(f'rank {torch.distributed.get_rank()}:{(ref_logprobs[0, -1, :] - logprobs[0, -1, :]).abs().sum()=}', flush=True)
+        # assert torch.allclose(ref_logprobs, logprobs)
+
         with torch.no_grad():
             kl_term = logprobs.detach() - ref_logprobs
             regularized_rewards = rewards - self.kl_coef * kl_term
-            print(f"{regularized_rewards.shape=}", flush=True)
 
             baselined_reward, baseline_metrics = self.reward_processor.baseline_rewards(rewards=regularized_rewards)
             loss = -baselined_reward * logprobs
@@ -457,7 +453,7 @@ def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in
         loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')
         self.store_metrics(metrics=metrics, train_eval='train')
 
-        logging.info(f'Compute Loss elapsed time:{time.time() - start}')
+        logging.info(f'get_batch_loss_metrics elapsed time:{time.time() - start}')
 
         gc.collect()
         torch.cuda.empty_cache()
@@ -526,7 +522,7 @@ def get_logprobs(
         records = {
             'input_ids': input_ids.cuda(),
             'attention_mask': attention_mask.cuda(),
-            'position_ids': position_ids.cuda(),
+            #'position_ids': position_ids.cuda(),
             #'use_cache': False
         }
@@ -534,13 +530,13 @@
            #TODO only one reference model -> maybe delete from group??
            return ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
         else:
-            # Memory efficient - a chain operations
+
             logits = F.log_softmax(
                 model(
                     input_ids=records['input_ids'],
                     attention_mask=records['attention_mask'],
-                    position_ids=records['position_ids'],
-                    use_cache=False,
+                    #position_ids=records['position_ids'],
+                    #use_cache=False,
                 ).logits[:, :-1] / self.args.temperature,
                 dim=-1
             )
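
The loss assembled in get_batch_loss_metrics is a KL-regularised REINFORCE objective: rewards are penalised by kl_coef * (logprobs - ref_logprobs), a baseline is subtracted via reward_processor.baseline_rewards, and the result weights the policy log-probabilities. A minimal sketch of that arithmetic on dummy per-completion values (the mean-of-others baseline below is only an illustrative stand-in for the real reward processor, and the scalar-per-completion shapes are an assumption):

import torch

logprobs = torch.randn(4, requires_grad=True)   # policy log-probs for 4 sampled completions
ref_logprobs = torch.randn(4)                   # frozen reference-model log-probs
rewards = torch.randn(4)                        # reward-model scores
kl_coef = 0.05

kl_term = logprobs.detach() - ref_logprobs             # per-sample KL estimate
regularized_rewards = rewards - kl_coef * kl_term      # KL-penalised rewards

# leave-one-out style baseline: mean of the other samples' penalised rewards
baseline = (regularized_rewards.sum() - regularized_rewards) / (regularized_rewards.numel() - 1)
baselined_reward = regularized_rewards - baseline

loss = (-baselined_reward * logprobs).mean()           # REINFORCE: -advantage * log pi
loss.backward()
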
diff --git a/turbo_alignment/trainers/online/reward_actor.py b/turbo_alignment/trainers/online/reward_actor.py
index beceae2..0c0d755 100644
--- a/turbo_alignment/trainers/online/reward_actor.py
+++ b/turbo_alignment/trainers/online/reward_actor.py
@@ -38,10 +38,9 @@ def forward(self, x):
         x = {k: v.cuda() for k, v in x.items()}
         # print('\n\n\nRM model', [v.shape for k, v in x.items()], 'RM model\n\n\n')
         # print(self.tokenizer.decode(x['input_ids'][0], skip_special_tokens=False))
-        #TODO reward from eos
-        for k, v in x.items():
-            if isinstance(v, torch.Tensor):
-                print(f'REWARD MODEL:{v.shape=}', flush=True)
+        # for k, v in x.items():
+        #     if isinstance(v, torch.Tensor):
+        #         print(f'REWARD MODEL:{v.shape=}', flush=True)
 
         return self.model(**x).logits
\ No newline at end of file
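
For reference, the reward actor's forward simply moves the tokenized prompt+completion batch to the GPU and returns the scoring head's logits. A minimal stand-alone sketch of the same scoring step; the checkpoint name and the AutoModelForSequenceClassification loader are placeholders and assumptions, not necessarily what the repo's RewardModel actor uses:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = 'my-org/reward-model'  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, torch_dtype=torch.bfloat16).cuda().eval()

batch = tokenizer(['<prompt text> <completion text>'], return_tensors='pt', padding=True)
batch = {k: v.cuda() for k, v in batch.items()}   # same device move as RewardModel.forward

with torch.no_grad():
    scores = model(**batch).logits                # one score per sequence, as in RewardModel.forward
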