Commit bdfc40d

fixes 5
Almaz Dautov committed Nov 9, 2024
1 parent 20cd742 commit bdfc40d
Showing 6 changed files with 57 additions and 16 deletions.
2 changes: 2 additions & 0 deletions turbo_alignment/common/tf/loaders/model/model.py
@@ -59,6 +59,8 @@ def load_model(
**model_settings.model_kwargs,
torch_dtype=torch.bfloat16,
)
import ray
print(f'node_id:{ray.get_runtime_context().get_node_id()}, gpu_id: {ray.get_runtime_context().get_accelerator_ids()["GPU"]}', flush=True)

if model_settings.transformers_settings.load_in_8bit:
model = prepare_model_for_int8_training(model)
4 changes: 3 additions & 1 deletion turbo_alignment/generators/rm.py
@@ -16,7 +16,9 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, **k

def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEncoding) -> torch.Tensor:
with torch.no_grad():
rewards = ray.get(self._model.forward.remote(records).logits.cpu())
# TODO: assumes there is only one reward model
records = {k: v.cuda() for k, v in records.items()}
rewards = ray.get(self._model.async_forward(records))[0].logits
return rewards # .squeeze()

def generate_from_batch(
3 changes: 3 additions & 0 deletions turbo_alignment/trainers/online/ray/rayactor_group.py
@@ -100,6 +100,9 @@ def _initiate_actors(self, pg, num_gpus_per_actor):
).remote(world_size, rank, local_rank, master_addr, master_port)
self._actor_handlers.append(worker_actor)

def async_forward(self, records: dict[str, torch.Tensor]):
return [actor.forward.remote(records) for actor in self._actor_handlers]

def async_eval(self):
return [actor.eval.remote() for actor in self._actor_handlers]

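The new async_forward mirrors async_eval: it submits the same batch to every actor in the group and returns the unresolved Ray object refs, leaving the caller to ray.get them (which is what rm.py and get_logprobs now do, indexing [0] on the assumption that the group holds a single model). Below is a minimal, self-contained sketch of that fan-out/gather pattern; EchoActor and TinyGroup are illustrative stand-ins, not classes from this repository.

import ray
import torch

@ray.remote
class EchoActor:
    """Illustrative stand-in for a RewardModel or ReferenceModel Ray actor."""

    def forward(self, records: dict[str, torch.Tensor]):
        # A real actor would run a model forward pass; here we just echo per-key sums.
        return {k: v.float().sum().item() for k, v in records.items()}

class TinyGroup:
    """Mimics RayGroup.async_forward: one .remote() call per actor, no blocking."""

    def __init__(self, actors):
        self._actor_handlers = actors

    def async_forward(self, records: dict[str, torch.Tensor]):
        return [actor.forward.remote(records) for actor in self._actor_handlers]

if __name__ == "__main__":
    ray.init()
    group = TinyGroup([EchoActor.remote()])
    records = {"input_ids": torch.ones(2, 4, dtype=torch.long)}
    # Same call shape as the generator: resolve the futures, take the first actor's output.
    print(ray.get(group.async_forward(records))[0])
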
6 changes: 5 additions & 1 deletion turbo_alignment/trainers/online/reference_actor.py
@@ -1,6 +1,7 @@
import ray
from turbo_alignment.trainers.online.ray.distributed_torch_ray_actor import DistributedTorchRayActor
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

@ray.remote(num_gpus=1)
class ReferenceModel(DistributedTorchRayActor):
@@ -11,7 +12,7 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):

def init_model_from_pretrained(self, pretrain):
self._setup_distributed()
self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda')
self.model = AutoModelForCausalLM.from_pretrained(pretrain, device_map='cuda', torch_dtype=torch.bfloat16)
self.tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True)
print(f"Reference model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))
@@ -26,5 +27,8 @@ def generate(self, text: str):
def eval(self):
return self.model.eval()

@torch.no_grad
def forward(self, x):
self.model.eval()
x = {k: v.cuda() for k, v in x.items()}
return self.model(**x)
45 changes: 32 additions & 13 deletions turbo_alignment/trainers/online/reinforce.py
@@ -102,6 +102,8 @@ def __init__(

# TODO_RLOO assert tp_size same for all engines
world_size = args.actor_settings["vllm_num_engines"] * args.actor_settings["vllm_tensor_parallel_size"] + 1

#TODO turn on nccl

backend = "nccl"
# https://github.com/OpenRLHF/OpenRLHF/issues/313
@@ -184,7 +186,7 @@ def __init__(
model=self.model, dataloader=self.get_train_dataloader()
)

# def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None) -> torch.Tensor:
# def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
# print('training step', flush=True)
# print(f'reinforce.py: {type(model)=}', flush=True)
# assert True == False
@@ -225,7 +227,6 @@ def _get_rm_generator(self, reward_model: torch.nn.Module | PreTrainedModel) ->
generator = RayRMSamplingGenerator(
model=reward_model,
tokenizer=self.tokenizer,
accelerator=self.accelerator,
)  # TODO: this type of critic is created for reward models with a CausalLM head and utilizes vLLM engines
case CriticType.DISTRIBUTED_VLLM:
generator = ...
@@ -264,7 +265,8 @@ def get_answers_and_rewards(

if do_broadcast:
# TODO: move to generator
assert isinstance(model, DeepSpeedEngine)
print('broadcast', type(model), flush=True)
# assert isinstance(model, DeepSpeedEngine)
self._broadcast_to_vllm(model)
torch.distributed.barrier()

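The commented-out isinstance(model, DeepSpeedEngine) assertion guarded what _broadcast_to_vllm presumably still does here: pushing the trainer's current weights to the generation workers before sampling, with the barrier ensuring nobody generates from stale parameters. The following is a generic sketch of that sync pattern in plain torch.distributed, not the repository's actual _broadcast_to_vllm, which additionally has to hand the tensors over to the vLLM engines.

import torch
import torch.distributed as dist

def broadcast_model_weights(model: torch.nn.Module, src_rank: int = 0) -> None:
    # Assumes the default process group is already initialised and every rank
    # holds a model with identical parameter shapes.
    for _, param in model.named_parameters():
        dist.broadcast(param.data, src=src_rank)
    dist.barrier()  # nobody starts generating until the new weights have landed
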
@@ -322,7 +324,7 @@ def get_batch_loss_metrics(
position_ids,
rewards,
) = self.get_answers_and_rewards(
model=self.accelerator.unwrap_model(model),
model=model,
inputs=inputs,
)

@@ -377,6 +379,7 @@ get_batch_loss_metrics(
'invalid_rewards',
],
):
tensor = tensor.cuda()
metrics = {
**metrics,
**get_log_mean_std(tensor, name, train_eval),
@@ -390,7 +393,7 @@

return loss.mean(), metrics

def compute_loss(self, model, inputs, return_outputs: bool = False):
def compute_loss(self, model, inputs, return_outputs: bool = False, num_items_in_batch=None):
loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')

self.store_metrics(metrics=metrics, train_eval='train')
@@ -450,16 +453,32 @@ def get_logprobs(
position_ids: torch.Tensor,
loss_mask: torch.Tensor,
):
logits = ray.get(model.forward.remote(
{'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids,
'use_cache': False,}
)).logits[:, :-1]
logits /= self.args.temperature
from turbo_alignment.trainers.online.ray.rayactor_group import RayGroup

input_ids = input_ids.cuda()
attention_mask = attention_mask.cuda()
position_ids = position_ids.cuda()

if isinstance(model, RayGroup):
records = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids,
#'use_cache': False
}
#TODO only one reference model -> maybe delete from group??
logits = ray.get(model.async_forward(records))[0].logits[:, :-1]
else:
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False,
).logits[:, :-1]

logits /= self.args.temperature
all_logprob = F.log_softmax(logits, dim=-1)

print(f'{all_logprob.device=}, {input_ids.device=}')
logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
logprob[~loss_mask[:, 1:].to(torch.bool)] = 0

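Whichever branch produces the logits, get_logprobs ends with the same per-token computation: scale by the sampling temperature, take a log-softmax, gather the log-probability of each next token, and zero the positions excluded by the loss mask. A standalone sketch of that tail end on dummy tensors, independent of Ray and the trainer:

import torch
import torch.nn.functional as F

def gather_token_logprobs(
    logits: torch.Tensor,      # (batch, seq_len, vocab)
    input_ids: torch.Tensor,   # (batch, seq_len)
    loss_mask: torch.Tensor,   # (batch, seq_len)
    temperature: float = 1.0,
) -> torch.Tensor:
    # Position t predicts token t + 1, so drop the last step and compare
    # against input_ids shifted left by one.
    scaled = logits[:, :-1] / temperature
    all_logprob = F.log_softmax(scaled, dim=-1)
    logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    return logprob.masked_fill(~loss_mask[:, 1:].bool(), 0.0)

# Dummy shapes only: (2, 5, 11) logits and (2, 5) ids give (2, 4) per-token log-probs.
print(gather_token_logprobs(torch.randn(2, 5, 11), torch.randint(0, 11, (2, 5)), torch.ones(2, 5), 0.7).shape)
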
13 changes: 12 additions & 1 deletion turbo_alignment/trainers/online/reward_actor.py
@@ -1,6 +1,7 @@
import ray
from turbo_alignment.trainers.online.ray.distributed_torch_ray_actor import DistributedTorchRayActor
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

@ray.remote(num_gpus=1)
class RewardModel(DistributedTorchRayActor):
@@ -11,8 +12,12 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port):

def init_model_from_pretrained(self, rm_model):
self._setup_distributed()
self.model = AutoModelForSequenceClassification.from_pretrained(rm_model, device_map='cuda')
self.model = AutoModelForSequenceClassification.from_pretrained(rm_model, device_map='cuda', torch_dtype=torch.bfloat16)
self.tokenizer = AutoTokenizer.from_pretrained(rm_model, trust_remote_code=True)

self.model.config.pad_token_id = self.model.config.eos_token_id
self.tokenizer.pad_token = self.tokenizer.eos_token

print(f"Reward model initialized on Node {self.node_id}, Local Rank {self.local_rank}")
print("GPU IDs: {}".format(ray.get_runtime_context().get_accelerator_ids()["GPU"]))

@@ -26,5 +31,11 @@ def generate(self, text: str):
def eval(self):
return self.model.eval()

@torch.no_grad
def forward(self, x):
self.model.eval()
x = {k: v.cuda() for k, v in x.items()}
print('\n\n\nRM model', [v.shape for k, v in x.items()], 'RM model\n\n\n')
# print(self.tokenizer.decode(x['input_ids'][0], skip_special_tokens=False))
#TODO reward from eos
return self.model(**x)

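About the new pad_token_id = eos_token_id lines: decoder-only checkpoints often ship without a pad token, and AutoModelForSequenceClassification needs config.pad_token_id to locate the last non-padding token of each row whenever the batch size is greater than one. A hedged sketch of that setup outside Ray; the checkpoint name is a placeholder, not one used by this repository.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

rm_name = "some/reward-model"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(rm_name)
model = AutoModelForSequenceClassification.from_pretrained(rm_name, torch_dtype=torch.bfloat16)

# Reuse EOS as the pad token so batched inputs can be padded and the
# classification head can find the last real token in every row.
if model.config.pad_token_id is None:
    model.config.pad_token_id = model.config.eos_token_id
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["good answer", "a much longer candidate answer"], padding=True, return_tensors="pt")
with torch.no_grad():
    rewards = model(**batch).logits  # (batch_size, num_labels)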