diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py
index aba80e9..7be3103 100755
--- a/turbo_alignment/common/tf/loaders/model/model.py
+++ b/turbo_alignment/common/tf/loaders/model/model.py
@@ -59,6 +59,8 @@ def load_model(
         **model_settings.model_kwargs,
         torch_dtype=torch.bfloat16,
     )
+    import ray
+    print(f'node_id:{ray.get_runtime_context().get_node_id()}, gpu_id: {ray.get_runtime_context().get_accelerator_ids()["GPU"]}', flush=True)
 
     if model_settings.transformers_settings.load_in_8bit:
         model = prepare_model_for_int8_training(model)
diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py
index 41d56a6..48681ca 100755
--- a/turbo_alignment/generators/rm.py
+++ b/turbo_alignment/generators/rm.py
@@ -16,7 +16,9 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, **k
 
     def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEncoding) -> torch.Tensor:
         with torch.no_grad():
-            rewards = ray.get(self._model.forward.remote(records).logits.cpu())
+            # TODO assuming that only one reward model
+            records = {k: v.cuda() for k, v in records.items()}
+            rewards = ray.get(self._model.async_forward(records))[0].logits
         return rewards  # .squeeze()
 
     def generate_from_batch(
diff --git a/turbo_alignment/trainers/online/ray/rayactor_group.py b/turbo_alignment/trainers/online/ray/rayactor_group.py
index e42b797..6a2ac41 100644
--- a/turbo_alignment/trainers/online/ray/rayactor_group.py
+++ b/turbo_alignment/trainers/online/ray/rayactor_group.py
@@ -100,6 +100,9 @@ def _initiate_actors(self, pg, num_gpus_per_actor):
             ).remote(world_size, rank, local_rank, master_addr, master_port)
             self._actor_handlers.append(worker_actor)
 
+    def async_forward(self, records: dict[str, torch.Tensor]):
+        return [actor.forward.remote(records) for actor in self._actor_handlers]
+
     def async_eval(self):
         return [actor.eval.remote() for actor in self._actor_handlers]
 
diff --git a/turbo_alignment/trainers/online/reference_actor.py b/turbo_alignment/trainers/online/reference_actor.py
index 74cba8f..4b0769c 100644
--- a/turbo_alignment/trainers/online/reference_actor.py
+++ b/turbo_alignment/trainers/online/reference_actor.py
@@ -1,6 +1,7 @@
 import ray
 from turbo_alignment.trainers.online.ray.distributed_torch_ray_actor import DistributedTorchRayActor
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
 @ray.remote(num_gpus=1)
 class ReferenceModel(DistributedTorchRayActor):
@@ -11,7 +12,7 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):
 
     def init_model_from_pretrained(self, pretrain):
         self._setup_distributed()
-        self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda')
+        self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16)
         self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
         print(f"Reference model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
         print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
@@ -26,5 +27,8 @@ def generate(self, text: str):
     def eval(self):
         return self.model.eval()
 
+    @torch.no_grad
     def forward(self, x):
+        self.model.eval()
+        x = {k: v.cuda() for k, v in x.items()}
         return self.model(**x)
\ No newline at end of file
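A minimal usage sketch (not part of the diff) of how the new RayGroup.async_forward fan-out is consumed on the driver side, as RayRMSamplingGenerator.generate_from_batch_records does above; the score_batch helper name and the single-reward-actor assumption are illustrative:

import ray
import torch

def score_batch(rm_group, records: dict[str, torch.Tensor]) -> torch.Tensor:
    # The actor runs the model on its own GPU, so inputs are moved to CUDA first,
    # mirroring generate_from_batch_records in the diff.
    records = {k: v.cuda() for k, v in records.items()}
    refs = rm_group.async_forward(records)   # one Ray ObjectRef per actor in the group
    outputs = ray.get(refs)                  # block until every remote forward finishes
    return outputs[0].logits                 # [0]: the group is assumed to hold a single reward model
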
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
index 23538ca..96bd463 100644
--- a/turbo_alignment/trainers/online/reinforce.py
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -102,6 +102,8 @@ def __init__(
         # TODO_RLOO assert tp_size same for all engines
         world_size = args.actor_settings["vllm_num_engines"] * args.actor_settings["vllm_tensor_parallel_size"] + 1
+
+        #TODO turn on nccl
         backend = "nccl"
         # https://github.com/OpenRLHF/OpenRLHF/issues/313
 
@@ -184,7 +186,7 @@ def __init__(
             model=self.model, dataloader=self.get_train_dataloader()
         )
 
-    # def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
+    # def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
     #     print('traning step', flush=True)
     #     print(f'reinforce.py: {type(model)=}', flush=True)
     #     assert True == False
@@ -225,7 +227,6 @@ def _get_rm_generator(self, reward_model: torch.nn.Module | PreTrainedModel) ->
                 generator = RayRMSamplingGenerator(
                     model=reward_model,
                     tokenizer=self.tokenizer,
-                    accelerator=self.accelerator,
                 )#TODO this type of critic is created for Reward models with CausalLM head and utilize vllm engines
             case CriticType.DISTRIBUTED_VLLM:
                 generator = ...
@@ -264,7 +265,8 @@ get_answers_and_rewards(
 
         if do_broadcast:
             # TODO: move to generator
-            assert isinstance(model, DeepSpeedEngine)
+            print('broadcast', type(model), flush=True)
+            # assert isinstance(model, DeepSpeedEngine)
             self._broadcast_to_vllm(model)
             torch.distributed.barrier()
@@ -322,7 +324,7 @@ get_batch_loss_metrics(
             position_ids,
             rewards,
         ) = self.get_answers_and_rewards(
-            model=self.accelerator.unwrap_model(model),
+            model=model,
             inputs=inputs,
         )
@@ -377,6 +379,7 @@ get_batch_loss_metrics(
                 'invalid_rewards',
             ],
         ):
+            tensor = tensor.cuda()
             metrics = {
                 **metrics,
                 **get_log_mean_std(tensor, name, train_eval),
@@ -390,7 +393,7 @@
 
         return loss.mean(), metrics
 
-    def compute_loss(self, model, inputs, return_outputs: bool = False):
+    def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None):
         loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')
 
         self.store_metrics(metrics=metrics, train_eval='train')
@@ -450,16 +453,32 @@ def get_logprobs(
         position_ids: torch.Tensor,
         loss_mask: torch.Tensor,
     ):
-        logits = ray.get(model.forward.remote(
-            {'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            'position_ids': position_ids,
-            'use_cache': False,}
-        )).logits[:, :-1]
-        logits /= self.args.temperature
+        from turbo_alignment.trainers.online.ray.rayactor_group import RayGroup
+
+        input_ids = input_ids.cuda()
+        attention_mask = attention_mask.cuda()
+        position_ids = position_ids.cuda()
+
+        if isinstance(model, RayGroup):
+            records = {
+                'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'position_ids': position_ids,
+                #'use_cache': False
+            }
+            #TODO only one reference model -> maybe delete from group??
+            logits = ray.get(model.async_forward(records))[0].logits[:, :-1]
+        else:
+            logits = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                use_cache=False,
+            ).logits[:, :-1]
+        logits /= self.args.temperature
         all_logprob = F.log_softmax(logits, dim=-1)
-
+        print(f'{all_logprob.device=}, {input_ids.device=}')
         logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
         logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
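A self-contained sketch of the per-token log-probability computation that get_logprobs performs once logits are obtained from either branch; the shapes, temperature value, and all-ones loss mask below are made up for illustration:

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
temperature = 1.0
input_ids = torch.randint(vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)

logits = torch.randn(batch, seq_len, vocab)[:, :-1]   # drop the final step, as in the trainer
logits /= temperature                                 # same temperature scaling
all_logprob = F.log_softmax(logits, dim=-1)
# Log-prob of each realized next token: gather along the vocab axis with the
# shifted input_ids, then zero out positions excluded by the loss mask.
logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
print(logprob.shape)                                  # torch.Size([2, 4]) == (batch, seq_len - 1)
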
diff --git a/turbo_alignment/trainers/online/reward_actor.py b/turbo_alignment/trainers/online/reward_actor.py
index 29d77c4..3dbb629 100644
--- a/turbo_alignment/trainers/online/reward_actor.py
+++ b/turbo_alignment/trainers/online/reward_actor.py
@@ -1,6 +1,7 @@
 import ray
 from turbo_alignment.trainers.online.ray.distributed_torch_ray_actor import DistributedTorchRayActor
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import torch
 
 @ray.remote(num_gpus=1)
 class RewardModel(DistributedTorchRayActor):
@@ -11,8 +12,12 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):
 
     def init_model_from_pretrained(self, rm_model):
         self._setup_distributed()
-        self.model = AutoModelForSequenceClassification.from_pretrained(rm_model, device_map='cuda')
+        self.model = AutoModelForSequenceClassification.from_pretrained(rm_model, device_map='cuda', torch_dtype=torch.bfloat16)
         self.tokenizer = AutoTokenizer.from_pretrained(rm_model, trust_remote_code=True)
+
+        self.model.config.pad_token_id = self.model.config.eos_token_id
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
         print(f"Reward model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
         print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
@@ -26,5 +31,11 @@ def generate(self, text: str):
     def eval(self):
         return self.model.eval()
 
+    @torch.no_grad
     def forward(self, x):
+        self.model.eval()
+        x = {k: v.cuda() for k, v in x.items()}
+        print('\n\n\nRM model', [v.shape for k, v in x.items()], 'RM model\n\n\n')
+        # print(self.tokenizer.decode(x['input_ids'][0], skip_special_tokens=False))
+        #TODO reward from eos
         return self.model(**x)
\ No newline at end of file
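The pad-token assignment in RewardModel.init_model_from_pretrained matters because sequence-classification heads on decoder-only models pool one score per sequence from the last non-pad position, which is located via config.pad_token_id. A small illustration of that pooling rule with dummy tensors (no Hugging Face model is loaded; shapes are invented, and the last non-pad position is derived from the attention mask here for simplicity):

import torch

logits = torch.randn(2, 6, 1)                      # [batch, seq_len, num_labels=1]
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
last_token_idx = attention_mask.sum(dim=1) - 1     # index of the last non-pad token per sequence
rewards = logits[torch.arange(logits.size(0)), last_token_idx].squeeze(-1)
print(rewards.shape)                               # torch.Size([2]): one reward per sequence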