debug
syrn1k committed Dec 2, 2024
1 parent 00446d9 commit 4db83cd
Showing 4 changed files with 17 additions and 31 deletions.
7 changes: 2 additions & 5 deletions turbo_alignment/cli/train.py
@@ -124,20 +124,17 @@ def reinforce_training(
from turbo_alignment.trainers.online.ray.rayactor_group import RayGroup
from turbo_alignment.trainers.online.ray.vllm_engine import create_vllm_engines
from turbo_alignment.trainers.online.reward_actor import RewardModel
from turbo_alignment.trainers.online.reference_actor import ReferenceModel


ray.init(address="auto")

experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path)

policy_models = RayGroup(num_nodes=2, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy)#64.19 GiB is allocated by PyTorch, and 3.40 GiB
reward_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=RewardModel)
reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)

# TODO_RLOO if possible hide init inside RayGroup
ray.get(policy_models.async_init_model_from_pretrained())
ray.get(reward_model.async_init_model_from_pretrained(rm_model=experiment_settings.reward_model_settings.model_path))
ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))

'''
TODO_RLOO:
@@ -159,5 +156,5 @@ def reinforce_training(
ray.get(policy_models.async_fit_actor_model(
experiment_settings=experiment_settings,
vllm_engines=vllm_engines,
reference_model=reference_model, reward_model=reward_model
reward_model=reward_model
))
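
For readability, here is roughly how the reinforce_training body reads after this hunk: the ReferenceModel import, its RayGroup, and its init call are gone, and async_fit_actor_model receives only the vLLM engines and the reward model. Identifiers are taken from the diff above; the import path for pipeline_settings and the construction of vllm_engines are assumptions (the engines are passed in as a parameter because create_vllm_engines' arguments are not shown here).

# Sketch of the post-commit flow; not a verbatim copy of the file.
import ray

from turbo_alignment import pipelines  # import path assumed
from turbo_alignment.settings import pipeline_settings  # import path assumed
from turbo_alignment.trainers.online.ray.rayactor_group import RayGroup
from turbo_alignment.trainers.online.reward_actor import RewardModel


def reinforce_training_sketch(experiment_settings_path: str, vllm_engines: list) -> None:
    ray.init(address="auto")

    experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(
        experiment_settings_path
    )

    # Only two actor groups remain: the trainable policy replicas and a single
    # reward-model actor. The reference model is now loaded inside the trainer
    # process instead (see pipelines/train/reinforce.py below).
    policy_models = RayGroup(num_nodes=2, num_gpus_per_node=8,
                             ray_actor_type=pipelines.TrainREINFORCEStrategy)
    reward_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=RewardModel)

    ray.get(policy_models.async_init_model_from_pretrained())
    ray.get(reward_model.async_init_model_from_pretrained(
        rm_model=experiment_settings.reward_model_settings.model_path
    ))

    ray.get(policy_models.async_fit_actor_model(
        experiment_settings=experiment_settings,
        vllm_engines=vllm_engines,
        reward_model=reward_model,
    ))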
10 changes: 4 additions & 6 deletions turbo_alignment/pipelines/train/reinforce.py
@@ -157,7 +157,6 @@ def _get_trainer(
vllm_engines,
training_args: REINFORCETrainingArguments,
experiment_settings: REINFORCETrainExperimentSettings,
ref_model,
reward_model,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
@@ -169,9 +168,9 @@

# TODO: TODO_RLOO load reference and reward model here

# ref_model = load_model(experiment_settings.model_settings, tokenizer)
# for _, param in ref_model.named_parameters():
# param.requires_grad = False
ref_model = load_model(experiment_settings.model_settings, tokenizer)
for _, param in ref_model.named_parameters():
param.requires_grad = False

# ref_model.eval()

@@ -222,7 +221,7 @@ def _get_datasets(self, experiment_settings: REINFORCETrainExperimentSettings) -
get rid of vllm_engines, reference_model, reward_model if possible
only get_trainer affected
'''
def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_model, reward_model) -> None:
def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reward_model) -> None:
training_args = self._get_training_args(experiment_settings)

print('HERE!!!!!!!!!!!!!!!!!!!!!!!!')
@@ -276,7 +275,6 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
vllm_engines,
training_args,
experiment_settings,
reference_model,
reward_model,
self.model,
self.tokenizer,
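
The change above replaces the Ray-hosted reference actor with a reference model loaded directly inside _get_trainer and frozen. As a small, self-contained illustration of that freeze-and-eval pattern, using plain Hugging Face transformers instead of the project's load_model helper (the checkpoint name is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustration only: the pipeline uses its own load_model(); this shows the
# same pattern of loading a reference policy that never receives gradients.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")

for _, param in ref_model.named_parameters():
    param.requires_grad = False  # freeze: the reference policy is never updated
ref_model.eval()  # eval mode also disables dropout, matching disable_dropout_in_model

with torch.no_grad():
    batch = tokenizer("a quick sanity check", return_tensors="pt")
    ref_logits = ref_model(**batch).logits  # forward passes only, no autograd graph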
3 changes: 0 additions & 3 deletions turbo_alignment/trainers/online/ray/rayactor_group.py
@@ -138,7 +138,6 @@ def async_fit_actor_model(
self,
experiment_settings: REINFORCETrainExperimentSettings,
vllm_engines: List,
reference_model,
reward_model,
):
refs = []
@@ -148,9 +147,7 @@
actor.run.remote(
experiment_settings=experiment_settings,
vllm_engines=vllm_engines,
reference_model=reference_model,
reward_model=reward_model,

)
)

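With the reference model dropped from the signature, async_fit_actor_model now fans run() out to every policy actor with only the experiment settings, the vLLM engines, and the reward model, collecting the returned object refs. A stand-alone Ray example of that dispatch-and-gather pattern, using a dummy actor (Trainer here is a stand-in, not the project's actor class):

import ray

ray.init()

@ray.remote
class Trainer:
    # A real actor would build the REINFORCE trainer and start training here.
    def run(self, experiment_settings, vllm_engines, reward_model):
        return f"trained with lr={experiment_settings['lr']}"

actors = [Trainer.remote() for _ in range(2)]
refs = [
    actor.run.remote(
        experiment_settings={"lr": 1e-6},
        vllm_engines=None,  # placeholders for the handles the real code passes
        reward_model=None,
    )
    for actor in actors
]
print(ray.get(refs))  # blocks until every actor has finished its run()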
28 changes: 11 additions & 17 deletions turbo_alignment/trainers/online/reinforce.py
@@ -156,13 +156,12 @@ def __init__(
logging.info(f'distributed vllm engine __init__ elapsed time:{time.time() - start}')

self.ref_model = ref_model
# TODO: delete later
# ray.get(self.ref_model.async_eval())

# TODO: TODO_RLOO watch later
# if self.ref_model is not None:
# self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
if self.ref_model is not None:
self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)

disable_dropout_in_model(self.model)
disable_dropout_in_model(self.ref_model)

self.reward_model = reward_model
# TODO: delete later
@@ -377,7 +376,7 @@ def get_batch_loss_metrics(
do_broadcast=True if train_eval == 'train' else False
)
start = time.time()
ref_logprobs = self.get_logprobs(
ref_logits, ref_logprobs = self.get_logprobs(
model=self.ref_model,
input_ids=query_response,
attention_mask=attention_mask,
@@ -392,15 +391,18 @@
gc.collect()
torch.cuda.empty_cache()

logprobs = self.get_logprobs(
logits, logprobs = self.get_logprobs(
model=model,
input_ids=query_response,
attention_mask=attention_mask,
position_ids=position_ids,
loss_mask=response_tokens_mask,
)

print(f"Logprob from training policy: {logprobs}; Logprob from reference policy: {ref_logprobs}")
print(f"LOGITS ALLCLOSE: {torch.allclose(ref_logits, logits)}")
print(f"LOGPROBS ALLCLOSE: {torch.allclose(ref_logprobs, logprobs)}")

print(f"REF LOGRPOBS: {ref_logprobs}; LOGRPOBS: {logprobs}")

logging.info(f'policy logprobs elapsed time: {time.time() - start}')
with torch.no_grad():
@@ -545,22 +547,13 @@ def get_logprobs(
#TODO only one reference model -> maybe delete from group??
lp = ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
else:

hash = 0
for p in model.parameters():
hash += p.data.sum().item()
print("TRAINABLE MODEL HASH: ", hash)

raw_logits = model(
input_ids=records['input_ids'],
attention_mask=records['attention_mask'],
position_ids=records['position_ids'],
use_cache=False,
).logits[:, :-1] / self.args.temperature

import logging
logging.info(f"LOGITS FROM TRAINING POLICY: {raw_logits} ; SUM: {raw_logits.sum()}")

# Memory-efficient chain of operations
logits = F.log_softmax(raw_logits, dim=-1)

@@ -571,6 +564,7 @@

lp = logprob.sum(-1)

return logits, lp
return lp

def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]:
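
After this commit get_logprobs returns both the log-softmaxed logits and the per-sequence log-prob sums, which is what allows get_batch_loss_metrics to compare the training and reference policies with torch.allclose while debugging. The snippet below is a minimal, self-contained sketch of that computation; the temperature scaling, log_softmax, gather, and masked sum follow the diff, while the exact label/mask shifting is my reading rather than a verbatim copy.

import torch
import torch.nn.functional as F

def sequence_logprobs(raw_logits: torch.Tensor,
                      input_ids: torch.Tensor,
                      loss_mask: torch.Tensor,
                      temperature: float = 1.0):
    # raw_logits: (batch, seq, vocab) from the causal LM
    # input_ids:  (batch, seq); loss_mask: (batch, seq), 1 on response tokens
    logits = raw_logits[:, :-1] / temperature            # predictions for positions 1..seq-1
    logps = F.log_softmax(logits, dim=-1)                # memory-efficient chain starts here
    labels = input_ids[:, 1:].unsqueeze(-1)              # tokens that were actually generated
    token_logps = torch.gather(logps, dim=-1, index=labels).squeeze(-1)
    token_logps = token_logps * loss_mask[:, 1:]         # keep only response positions
    return logps, token_logps.sum(-1)                    # full distribution + per-sequence sums

# Toy usage: two sequences of six tokens over a ten-token vocabulary.
raw = torch.randn(2, 6, 10)
ids = torch.randint(0, 10, (2, 6))
mask = torch.ones(2, 6)
dist, per_seq = sequence_logprobs(raw, ids, mask, temperature=1.0)
print(per_seq.shape)  # torch.Size([2])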
