Commit

reference model update
Almaz Dautov committed Dec 3, 2024
1 parent 1892559 commit 0837aee
Showing 7 changed files with 50 additions and 65 deletions.
11 changes: 6 additions & 5 deletions turbo_alignment/cli/train.py
@@ -124,20 +124,20 @@ def reinforce_training(
from turbo_alignment.trainers.online.ray.rayactor_group import RayGroup
from turbo_alignment.trainers.online.ray.vllm_engine import create_vllm_engines
from turbo_alignment.trainers.online.reward_actor import RewardModel
from turbo_alignment.trainers.online.reference_actor import ReferenceModel
# from turbo_alignment.trainers.online.reference_actor import ReferenceModel

ray.init(address="auto")

experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path)

policy_models = RayGroup(num_nodes=2, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy)#64.19 GiB is allocated by PyTorch, and 3.40 GiB
policy_models = RayGroup(num_nodes=2, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy)
reward_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=RewardModel)
reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)
# reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)

# TODO_RLOO if possible hide init inside RayGroup
ray.get(policy_models.async_init_model_from_pretrained())
ray.get(reward_model.async_init_model_from_pretrained(rm_model=experiment_settings.reward_model_settings.model_path))
ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))
# ray.get(reference_model.async_init_model_from_pretrained(pretrain=experiment_settings.model_settings.model_path))

'''
TODO_RLOO:
@@ -159,5 +159,6 @@ def reinforce_training(
ray.get(policy_models.async_fit_actor_model(
experiment_settings=experiment_settings,
vllm_engines=vllm_engines,
reference_model=reference_model, reward_model=reward_model
# reference_model=reference_model,
reward_model=reward_model
))
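
For orientation: after this hunk only the policy and reward actor groups remain, and the reference model no longer gets its own RayGroup; it is handled inside the trainer instead (see the later files in this commit). Below is a minimal, self-contained sketch of the dispatch pattern that RayGroup wraps. Everything in it (RewardActor, score, the model path) is a hypothetical stand-in; only the init-then-ray.get rhythm mirrors the call sites above.

    # Toy stand-in for the RayGroup pattern: create actors, initialize them
    # asynchronously, then block on ray.get, as the pipeline does above.
    import ray

    ray.init()  # the real pipeline attaches to a running cluster via ray.init(address="auto")

    @ray.remote
    class RewardActor:  # hypothetical actor; the repo's RewardModel also pins GPUs per node
        def init_model_from_pretrained(self, rm_model: str) -> str:
            self.model_path = rm_model  # a real actor would load weights here
            return f"loaded {rm_model}"

        def score(self, text: str) -> float:
            return float(len(text))  # placeholder reward

    actors = [RewardActor.remote() for _ in range(2)]
    ray.get([a.init_model_from_pretrained.remote("path/to/reward-model") for a in actors])
    print(ray.get([a.score.remote("hello world") for a in actors]))
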
5 changes: 5 additions & 0 deletions turbo_alignment/common/tf/loaders/model/model.py
@@ -40,6 +40,11 @@ def _load_pretrained_adapters(
is_trainable=model_settings.is_trainable,
)

def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0


def unfreeze_params(layer):
for param in layer.parameters():
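
The new disable_dropout_in_model helper sets the probability of every nn.Dropout submodule to zero; the REINFORCE trainer later in this commit applies it to both the policy and the reference model so that repeated forward passes over the same batch agree exactly. A small standalone check of that behaviour (the toy Sequential model is arbitrary):

    import torch

    def disable_dropout_in_model(model: torch.nn.Module) -> None:
        # same helper as in the hunk above: zero out every Dropout's probability
        for module in model.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = 0

    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.1), torch.nn.Linear(8, 2))
    disable_dropout_in_model(toy)
    assert all(m.p == 0 for m in toy.modules() if isinstance(m, torch.nn.Dropout))

    x = torch.randn(4, 8)
    toy.train()  # even in train mode, a zeroed dropout is a no-op
    assert torch.equal(toy(x), toy(x))
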
30 changes: 6 additions & 24 deletions turbo_alignment/pipelines/train/reinforce.py
@@ -91,21 +91,6 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):

def init_model_from_pretrained(self):
self._setup_distributed()

# self.ds_config = get_train_ds_config(offload=False)
# self.ds_config["train_micro_batch_size_per_gpu"] = 1

# self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda')
# self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
# print(f"PolicyModel initialized on Node {self.node_id}, Local Rank {self.local_rank}")
# print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))

# def tokenize(self, text: str):
# return self.tokenizer(text, return_tensors='pt')

# def generate(self, text: str):
# tokenized_input = self.tokenize(text).to('cuda')
# return self.model(**tokenized_input)

@staticmethod
def _get_data_collator(
@@ -157,7 +142,7 @@ def _get_trainer(
vllm_engines,
training_args: REINFORCETrainingArguments,
experiment_settings: REINFORCETrainExperimentSettings,
ref_model,
# ref_model,
reward_model,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
@@ -169,9 +154,9 @@ def _get_trainer(

# TODO: TODO_RLOO load reference and reward model here

# ref_model = load_model(experiment_settings.model_settings, tokenizer)
# for _, param in ref_model.named_parameters():
# param.requires_grad = False
ref_model = load_model(experiment_settings.model_settings, tokenizer)
for _, param in ref_model.named_parameters():
param.requires_grad = False

# ref_model.eval()

@@ -222,7 +207,7 @@ def _get_datasets(self, experiment_settings: REINFORCETrainExperimentSettings) -
get rid off vllm_engines, reference_model, reward_model if possible
only get_trainer affected
'''
def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_model, reward_model) -> None:
def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reward_model) -> None: #reference_model
training_args = self._get_training_args(experiment_settings)

print('HERE!!!!!!!!!!!!!!!!!!!!!!!!')
@@ -276,7 +261,7 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
vllm_engines,
training_args,
experiment_settings,
reference_model,
# reference_model,
reward_model,
self.model,
self.tokenizer,
@@ -285,8 +270,6 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
data_collator,
)
print(f"Elapsed get_trainer time: {time.time() - start} seconds")

start = time.time()

if self.trainer.accelerator.is_main_process:
self._dataset_and_collator_sanity_check(train_dataset, data_collator)
@@ -302,7 +285,6 @@ def run(self, experiment_settings: ExperimentSettingsT, vllm_engines, reference_
self._save_experiment_metadata(
experiment_metadata, Path(self.trainer.args.output_dir) / 'experiment_metadata.json'
)
print(f"Elapsed before trainer.train() time: {time.time() - start} seconds")
self.trainer.train()

self.trainer.save_model()
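
The net effect of this file's changes is that run() no longer receives a remote reference actor; _get_trainer now builds the reference model in-process from the same model settings as the policy and freezes its parameters. A minimal sketch of that freeze pattern with a plain Hugging Face model — the tiny checkpoint name is a placeholder, and the eval() call is the usual companion step (it is still commented out in the hunk above):

    import torch
    from transformers import AutoModelForCausalLM

    # Load a second copy of the (policy) weights to act as the frozen reference.
    ref_model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder checkpoint

    for _, param in ref_model.named_parameters():
        param.requires_grad = False  # gradients never flow into the reference
    ref_model.eval()                 # deterministic, inference-only behaviour

    assert not any(p.requires_grad for p in ref_model.parameters())
    with torch.no_grad():
        _ = ref_model(input_ids=torch.tensor([[1, 2, 3]]))
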
4 changes: 2 additions & 2 deletions turbo_alignment/trainers/online/ray/rayactor_group.py
@@ -138,7 +138,7 @@ def async_fit_actor_model(
self,
experiment_settings: REINFORCETrainExperimentSettings,
vllm_engines: List,
reference_model,
# reference_model,
reward_model,
):
refs = []
@@ -148,7 +148,7 @@
actor.run.remote(
experiment_settings=experiment_settings,
vllm_engines=vllm_engines,
reference_model=reference_model,
# reference_model=reference_model,
reward_model=reward_model,

)
8 changes: 5 additions & 3 deletions turbo_alignment/trainers/online/reference_actor.py
@@ -14,7 +14,7 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):

def init_model_from_pretrained(self, pretrain):
self._setup_distributed()
self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16) #attn_implementation='flash_attention_2'
self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16, attn_implementation='flash_attention_2')
self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
print(f"Reference model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
@@ -42,12 +42,14 @@ def reference_forward(self, x, temperature, loss_mask):
x = {k: v.cuda() for k, v in x.items()}

print(f"{x.keys()}")
logits = self.model(**x).logits[:, :-1] # 35GB
logits = self.model(**x).logits[:, :-1]

logits /= temperature

#logits = F.log_softmax(logits, dim=-1) # 35GB
'''
Memory-efficient implementation of log_softmax using in-place operations,
equivalent to:
logits = F.log_softmax(logits, dim=-1)
'''
torch.exp(logits, out=logits)
summed = torch.sum(logits, dim=-1, keepdim=True)
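
The hunk above is truncated right after the summation, so the tail of the memory-efficient chain is not shown here. For reference, this is one way the in-place computation can be completed so that it matches F.log_softmax; skipping the usual max-subtraction trades numerical stability for memory, so it assumes logits of moderate magnitude:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 10)
    expected = F.log_softmax(logits, dim=-1)

    # In-place equivalent: exp -> normalize -> log, reusing one buffer.
    work = logits.clone()
    torch.exp(work, out=work)                       # work <- exp(x)
    summed = torch.sum(work, dim=-1, keepdim=True)  # row sums of exp(x)
    work.div_(summed)                               # work <- softmax(x)
    work.log_()                                     # work <- log_softmax(x)

    assert torch.allclose(work, expected, atol=1e-5)
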
50 changes: 23 additions & 27 deletions turbo_alignment/trainers/online/reinforce.py
@@ -3,6 +3,7 @@
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.distributed
import torch.nn.functional as F
import torch.utils
import torch.utils.data
@@ -31,7 +32,7 @@
ActorType,
RewardProcessorType,
)

from turbo_alignment.common.tf.loaders.model.model import disable_dropout_in_model
from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings
from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
from turbo_alignment.trainers.online.reward_processor import RewardProcessor
@@ -98,12 +99,12 @@ def __init__(
**kwargs,
)
logging.info(f'super().__init__ elapsed time:{time.time() - start}')
print(f'{self.accelerator.num_processes=}', flush=True)

self.vllm_engines = vllm_engines
start = time.time()
if self.vllm_engines is not None and torch.distributed.get_rank() == 0:
master_address = ray._private.services.get_node_ip_address()
print(f'TRAINER DEBUG: {master_address=}', flush=True)
with socket.socket() as sock:
sock.bind(("", 0))
master_port = sock.getsockname()[1]
@@ -150,16 +151,15 @@ def __init__(
logging.info(f'distributed vllm engine __init__ elapsed time:{time.time() - start}')

self.ref_model = ref_model
# TODO: delete later
# ray.get(self.ref_model.async_eval())

# TODO: TODO_RLOO watch later
# if self.ref_model is not None:
# self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
if self.ref_model is not None:
self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)

disable_dropout_in_model(self.model)
disable_dropout_in_model(self.ref_model)

self.reward_model = reward_model
# TODO: delete later
# ray.get(self.ref_model.async_eval())

self._stored_metrics: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
self.kl_coef = args.kl_coef
@@ -202,7 +202,6 @@ def __init__(
)
logging.info(f'statictis in __init__ elapsed time:{time.time() - start}')
self.print_readable_stats()


def _broadcast_to_vllm(self, model: DeepSpeedEngine):
# avoid OOM
@@ -297,18 +296,8 @@ def get_answers_and_rewards(
)
logging.info(f'generations elapsed time:{time.time() - start}')

for g in generations:
for ans in g.answers:
print(f'{g.input_token_ids.device=}, {ans.answer_token_ids.device=}', flush=True)
print(f'{g.input_attention_mask.device=}, {ans.answer_attention_mask.device=}', flush=True)
break
break
response_ids = [torch.cat([g.input_token_ids, ans.answer_token_ids], dim=1) for g in generations for ans in g.answers]
response_attention_mask = [torch.cat([g.input_attention_mask, ans.answer_attention_mask], dim=1) for g in generations for ans in g.answers]

import random
ind = random.randint(0, len(response_ids) - 1)
assert response_ids[ind].shape == response_attention_mask[ind].shape

if torch.distributed.get_rank() == 0:
print(f'Prompt with completion at index [0] shape: {response_ids[0].shape}', flush=True)
@@ -375,16 +364,18 @@ def get_batch_loss_metrics(
input_ids=query_response,
attention_mask=attention_mask,
position_ids=position_ids,
loss_mask=response_tokens_mask.to(torch.bfloat16),
loss_mask=response_tokens_mask,
)

logging.info(f'reference elapsed time:{time.time() - start}')
rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)

start = time.time()

del inputs
gc.collect()
torch.cuda.empty_cache()

start = time.time()
logprobs = self.get_logprobs(
model=model,
input_ids=query_response,
Expand All @@ -393,11 +384,16 @@ def get_batch_loss_metrics(
loss_mask=response_tokens_mask,
)
logging.info(f'policy logprobs elapsed time:{time.time() - start}')

# print(f'rank {torch.distributed.get_rank()}:{ref_logprobs.shape=}, {logprobs.shape=}', flush=True)
# print(f'rank {torch.distributed.get_rank()}:{ref_logprobs=}, {logprobs=}', flush=True)
# print(f'rank {torch.distributed.get_rank()}:{(ref_logprobs[0, -1, :] - logprobs[0, -1, :]).abs().sum()=}', flush=True)
# assert torch.allclose(ref_logprobs, logprobs)

with torch.no_grad():
kl_term = logprobs.detach() - ref_logprobs
regularized_rewards = rewards - self.kl_coef * kl_term

print(f"{regularized_rewards.shape=}", flush=True)
baselined_reward, baseline_metrics = self.reward_processor.baseline_rewards(rewards=regularized_rewards)

loss = -baselined_reward * logprobs
@@ -457,7 +453,7 @@ def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in
loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')

self.store_metrics(metrics=metrics, train_eval='train')
logging.info(f'Compute Loss elapsed time:{time.time() - start}')
logging.info(f'get_batch_loss_metrics elapsed time:{time.time() - start}')
gc.collect()
torch.cuda.empty_cache()

@@ -526,21 +522,21 @@ def get_logprobs(
records = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.cuda(),
#'position_ids': position_ids.cuda(),
#'use_cache': False
}

if isinstance(model, RayGroup):
#TODO only one reference model -> maybe delete from group??
return ray.get(model.reference_forward(records, self.args.temperature, loss_mask))
else:
# Memory efficient - a chain of operations

logits = F.log_softmax(
model(
input_ids=records['input_ids'],
attention_mask=records['attention_mask'],
position_ids=records['position_ids'],
use_cache=False,
#position_ids=records['position_ids'],
#use_cache=False,
).logits[:, :-1] / self.args.temperature,
dim=-1
)
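
To make the loss in get_batch_loss_metrics easier to follow in isolation, here is the same KL-regularized REINFORCE computation on toy tensors. The repo delegates the baseline to RewardProcessor (not part of this diff), so a simple mean baseline stands in for it, and all tensor values are made up:

    import torch

    kl_coef = 0.05                                  # plays the role of args.kl_coef
    logprobs = torch.randn(4, requires_grad=True)   # sequence log-probs under the policy
    ref_logprobs = torch.randn(4)                   # log-probs under the frozen reference
    rewards = torch.randn(4)                        # reward-model scores

    # Penalize drift from the reference model.
    kl_term = logprobs.detach() - ref_logprobs
    regularized_rewards = rewards - kl_coef * kl_term

    # Stand-in baseline: subtract the batch mean.
    baselined_reward = regularized_rewards - regularized_rewards.mean()

    loss = (-baselined_reward * logprobs).mean()
    loss.backward()  # gradients flow only through logprobs
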
7 changes: 3 additions & 4 deletions turbo_alignment/trainers/online/reward_actor.py
@@ -38,10 +38,9 @@ def forward(self, x):
x = {k: v.cuda() for k, v in x.items()}
# print('\n\n\nRM model', [v.shape for k, v in x.items()], 'RM model\n\n\n')
# print(self.tokenizer.decode(x['input_ids'][0], skip_special_tokens=False))
#TODO reward from eos

for k, v in x.items():
if isinstance(v, torch.Tensor):
print(f'REWARD MODEL:{v.shape=}', flush=True)
# for k, v in x.items():
# if isinstance(v, torch.Tensor):
# print(f'REWARD MODEL:{v.shape=}', flush=True)

return self.model(**x).logits
