From 00446d9d4dcb5974507a8776ace4bbfc463008a1 Mon Sep 17 00:00:00 2001
From: "n.surnachev"
Date: Mon, 2 Dec 2024 18:56:40 +0300
Subject: [PATCH] debug

---
 .../trainers/online/reference_actor.py       | 45 ++++++++++++++++---
 turbo_alignment/trainers/online/reinforce.py | 45 ++++++++++++++-----
 2 files changed, 71 insertions(+), 19 deletions(-)

diff --git a/turbo_alignment/trainers/online/reference_actor.py b/turbo_alignment/trainers/online/reference_actor.py
index bcbf633..e08128b 100644
--- a/turbo_alignment/trainers/online/reference_actor.py
+++ b/turbo_alignment/trainers/online/reference_actor.py
@@ -3,8 +3,16 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import torch.nn.functional as F
+from torch.cuda.amp import autocast
 import gc
+
+
+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
+
 
 @ray.remote(num_gpus=1)
 class ReferenceModel(DistributedTorchRayActor):
     def __init__(self, world_size, rank, local_rank, master_addr, master_port):
@@ -15,7 +23,16 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):
     def init_model_from_pretrained(self, pretrain):
         self._setup_distributed()
         self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16) #attn_implementation='flash_attention_2'
+
+        disable_dropout_in_model(self.model)
+
+        hash = 0
+        for p in self.model.parameters():
+            hash += p.data.sum().item()
+        print("RAY REFERENCE MODEL HASH: ", hash)
+
         self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
+        self.model.resize_token_embeddings(len(self.tokenizer))
         print(f"Reference model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
         print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
@@ -38,21 +55,35 @@ def forward(self, x):
     @torch.no_grad
     def reference_forward(self, x, temperature, loss_mask):
         torch.cuda.empty_cache()
-        self.model.eval()
         x = {k: v.cuda() for k, v in x.items()}
         print(f"{x.keys()}")
-        logits = self.model(**x).logits[:, :-1] # 35GB
-        logits /= temperature
+        # logits = self.model(**x).logits[:, :-1] # 35GB
+        # logits /= temperature
         #logits = F.log_softmax(logits, dim=-1) # 35GB
         '''
        Memory Efficient implementation of log_softmax using in_place operation
         '''
-        torch.exp(logits, out=logits)
-        summed = torch.sum(logits, dim=-1, keepdim=True)
-        logits /= summed
-        torch.log(logits, out=logits)
+        # torch.exp(logits, out=logits)
+        # summed = torch.sum(logits, dim=-1, keepdim=True)
+        # logits /= summed
+        # torch.log(logits, out=logits)
+
+        raw_logits = self.model(
+            input_ids=x['input_ids'],
+            attention_mask=x['attention_mask'],
+            position_ids=x['position_ids'],
+            use_cache=False,
+        ).logits[:, :-1] / temperature
+
+        with autocast(dtype=torch.float32):
+            logits = F.log_softmax(
+                raw_logits,
+                dim=-1
+            )
+
+        print(f"LOGITS FROM REFERENCE POLICY: {raw_logits} ; SUM: {raw_logits.sum()}")
 
         logprob = torch.gather(logits, 2, x['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)
         logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index da31ba2..7a24f5d 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -44,6 +44,12 @@
 from deepspeed.runtime.engine import DeepSpeedEngine
 import gc
+
+def disable_dropout_in_model(model: torch.nn.Module) -> None:
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
+
 
 # FIXME
 @dataclass
 class REINFORCETrainingArguments(TrainingArguments):
@@ -156,6 +162,7 @@ def __init__(
         # TODO: TODO_RLOO watch later
         # if self.ref_model is not None:
         #     self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
+        disable_dropout_in_model(self.model)
 
         self.reward_model = reward_model
         # TODO: delete later
@@ -375,7 +382,7 @@ def get_batch_loss_metrics(
             input_ids=query_response,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            loss_mask=response_tokens_mask.to(torch.bfloat16),
+            loss_mask=response_tokens_mask,
         )
         logging.info(f'reference elapsed time:{time.time() - start}')
         rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)
@@ -392,6 +399,9 @@ def get_batch_loss_metrics(
             position_ids=position_ids,
             loss_mask=response_tokens_mask,
         )
+
+        print(f"Logprob from training policy: {logprobs}; Logprob from reference policy: {ref_logprobs}")
+
         logging.info(f'policy logrobs elapsed time:{time.time() - start}')
         with torch.no_grad():
             kl_term = logprobs.detach() - ref_logprobs
@@ -530,27 +540,38 @@ def get_logprobs(
             #'use_cache': False
         }
 
+        lp = None
         if isinstance(model, RayGroup):
             #TODO only one reference model -> maybe delete from group??
-            return ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
+            lp = ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
         else:
+
+            hash = 0
+            for p in model.parameters():
+                hash += p.data.sum().item()
+            print("TRAINABLE MODEL HASH: ", hash)
+
+            raw_logits = model(
+                input_ids=records['input_ids'],
+                attention_mask=records['attention_mask'],
+                position_ids=records['position_ids'],
+                use_cache=False,
+            ).logits[:, :-1] / self.args.temperature
+
+            import logging
+            logging.info(f"LOGITS FROM TRAINING POLICY: {raw_logits} ; SUM: {raw_logits.sum()}")
+
             # Memory efficient - a chain operations
-            logits = F.log_softmax(
-                model(
-                    input_ids=records['input_ids'],
-                    attention_mask=records['attention_mask'],
-                    position_ids=records['position_ids'],
-                    use_cache=False,
-                ).logits[:, :-1] / self.args.temperature,
-                dim=-1
-            )
+            logits = F.log_softmax(raw_logits, dim=-1)
 
             # logits /= self.args.temperature
             # all_logprob = F.log_softmax(logits, dim=-1)
             logprob = torch.gather(logits, 2, records['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)
             logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
-            return logprob.sum(-1)
+            lp = logprob.sum(-1)
+
+        return lp
 
     def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.args.non_eos_penalty:
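
A minimal standalone sketch of the log-prob path that both reference_forward and get_logprobs follow after this patch: temperature-scaled logits, log_softmax in float32, a gather of the realized next-token log-probs, masking of non-response positions, and a per-sequence sum. The helper name sequence_logprob and the tensor shapes are hypothetical and only illustrate the computation, not the trainer's API.

import torch
import torch.nn.functional as F

def sequence_logprob(logits, input_ids, loss_mask, temperature):
    # logits: [batch, seq_len, vocab]. Drop the last position so that position t
    # scores token t+1, and apply temperature before normalizing.
    logits = logits[:, :-1] / temperature
    # log_softmax in float32 for numerical stability (the patch wraps this in autocast).
    logprobs = F.log_softmax(logits.float(), dim=-1)
    # Pick the log-prob of each realized next token.
    token_logprob = torch.gather(logprobs, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    # Zero out prompt/padding positions so only response tokens contribute.
    token_logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
    return token_logprob.sum(-1)  # one summed log-prob per sequence

# Toy usage with made-up shapes.
batch, seq_len, vocab = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)
print(sequence_logprob(logits, input_ids, loss_mask, temperature=1.0))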