
Commit 00446d9: debug
syrn1k committed Dec 2, 2024
1 parent 1892559
Showing 2 changed files with 71 additions and 19 deletions.
45 changes: 38 additions & 7 deletions turbo_alignment/trainers/online/reference_actor.py
@@ -3,8 +3,16 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
import gc


def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0


@ray.remote(num_gpus=1)
class ReferenceModel(DistributedTorchRayActor):
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
@@ -15,7 +23,16 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):
def init_model_from_pretrained(self, pretrain):
self._setup_distributed()
self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16) #attn_implementation='flash_attention_2'

disable_dropout_in_model(self.model)

hash = 0
for p in self.model.parameters():
hash += p.data.sum().item()
print("RAY REFERENCE MODEL HASH: ", hash)

self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
self.model.resize_token_embeddings(len(self.tokenizer))
print(f"Reference model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))

Expand All @@ -38,21 +55,35 @@ def forward(self, x):
@torch.no_grad
def reference_forward(self, x, temperature, loss_mask):
torch.cuda.empty_cache()
self.model.eval()
x = {k: v.cuda() for k, v in x.items()}

print(f"{x.keys()}")
logits = self.model(**x).logits[:, :-1] # 35GB
logits /= temperature
# logits = self.model(**x).logits[:, :-1] # 35GB
# logits /= temperature

#logits = F.log_softmax(logits, dim=-1) # 35GB
'''
Memory Efficient implementation of log_softmax using in_place operation
'''
torch.exp(logits, out=logits)
summed = torch.sum(logits, dim=-1, keepdim=True)
logits /= summed
torch.log(logits, out=logits)
# torch.exp(logits, out=logits)
# summed = torch.sum(logits, dim=-1, keepdim=True)
# logits /= summed
# torch.log(logits, out=logits)

raw_logits = self.model(
input_ids=x['input_ids'],
attention_mask=x['attention_mask'],
position_ids=x['position_ids'],
use_cache=False,
).logits[:, :-1] / temperature

with autocast(dtype=torch.float32):
logits = F.log_softmax(
raw_logits,
dim=-1
)

print(f"LOGITS FROM REFERENCE POLICY: {raw_logits} ; SUM: {raw_logits.sum()}")

logprob = torch.gather(logits, 2, x['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)
logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
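A note on the reference_forward change above: the commit replaces the commented-out in-place exp/sum/log chain with a plain F.log_softmax computed in float32 under autocast. If the memory-efficient route is ever needed again, a minimal sketch (not part of this commit, helper name is made up) is to subtract logsumexp along the vocabulary dimension in place, which avoids materialising a second [batch, seq, vocab] tensor and stays numerically stable because logsumexp shifts by the max internally:

import torch

def log_softmax_(logits: torch.Tensor) -> torch.Tensor:
    # logsumexp reduces the vocab dim to size 1, so the only extra allocation is tiny
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)
    # subtract in place: logits now holds log-probabilities
    return logits.sub_(lse)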
45 changes: 33 additions & 12 deletions turbo_alignment/trainers/online/reinforce.py
@@ -44,6 +44,12 @@
from deepspeed.runtime.engine import DeepSpeedEngine
import gc


def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0

# FIXME
@dataclass
class REINFORCETrainingArguments(TrainingArguments):
@@ -156,6 +162,7 @@ def __init__(
# TODO: TODO_RLOO watch later
# if self.ref_model is not None:
# self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
disable_dropout_in_model(self.model)

self.reward_model = reward_model
# TODO: delete later
@@ -375,7 +382,7 @@ def get_batch_loss_metrics(
input_ids=query_response,
attention_mask=attention_mask,
position_ids=position_ids,
loss_mask=response_tokens_mask.to(torch.bfloat16),
loss_mask=response_tokens_mask,
)
logging.info(f'reference elapsed time:{time.time() - start}')
rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)
@@ -392,6 +399,9 @@ def get_batch_loss_metrics(
position_ids=position_ids,
loss_mask=response_tokens_mask,
)

print(f"Logprob from training policy: {logprobs}; Logprob from reference policy: {ref_logprobs}")

logging.info(f'policy logprobs elapsed time:{time.time() - start}')
with torch.no_grad():
kl_term = logprobs.detach() - ref_logprobs
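For context on the unchanged line above: kl_term is the simple log-ratio estimator of the KL penalty. With get_logprobs returning the per-sample sum of token log-probabilities, kl_term = log pi(y|x) - log pi_ref(y|x), and its batch mean approximates KL(pi || pi_ref). A toy illustration, assuming logprobs and ref_logprobs are already those per-sample sums:

import torch

logprobs = torch.tensor([-42.3, -57.1])      # sums of token log-probs under the training policy
ref_logprobs = torch.tensor([-40.9, -55.8])  # same sums under the frozen reference policy
kl_term = logprobs - ref_logprobs            # per-sample estimate; batch mean approximates KL(pi || pi_ref)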
@@ -530,27 +540,38 @@ def get_logprobs(
#'use_cache': False
}

lp = None
if isinstance(model, RayGroup):
#TODO only one reference model -> maybe delete from group??
return ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
lp = ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
else:

hash = 0
for p in model.parameters():
hash += p.data.sum().item()
print("TRAINABLE MODEL HASH: ", hash)

raw_logits = model(
input_ids=records['input_ids'],
attention_mask=records['attention_mask'],
position_ids=records['position_ids'],
use_cache=False,
).logits[:, :-1] / self.args.temperature

import logging
logging.info(f"LOGITS FROM TRAINING POLICY: {raw_logits} ; SUM: {raw_logits.sum()}")

# Memory efficient - a chain of operations
logits = F.log_softmax(
model(
input_ids=records['input_ids'],
attention_mask=records['attention_mask'],
position_ids=records['position_ids'],
use_cache=False,
).logits[:, :-1] / self.args.temperature,
dim=-1
)
logits = F.log_softmax(raw_logits, dim=-1)

# logits /= self.args.temperature
# all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(logits, 2, records['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)
logprob[~loss_mask[:, 1:].to(torch.bool)] = 0

return logprob.sum(-1)
lp = logprob.sum(-1)

return lp
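The "HASH" prints added here and in reference_actor.py are weight fingerprints: summing every parameter is a cheap way to check that the trainable policy and the Ray reference actor start from identical weights. A standalone sketch of the same idea (helper name is hypothetical, not part of the commit):

import torch

def param_checksum(model: torch.nn.Module) -> float:
    # Sum all parameters in float64: identical weights give identical sums,
    # so differing checksums flag a mismatched or re-initialised model.
    with torch.no_grad():
        return sum(p.detach().double().sum().item() for p in model.parameters())

# Usage: compare param_checksum(policy) against the value printed by the reference actor.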

def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]:
if self.args.non_eos_penalty:
