rm replicas added
thehir0 committed Dec 24, 2024
1 parent b94156e commit c0c3e14
Showing 8 changed files with 31 additions and 11 deletions.
4 changes: 3 additions & 1 deletion turbo_alignment/cli/train.py
@@ -131,7 +131,9 @@ def reinforce_training(
experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path)

policy_models = RayGroup(num_nodes=experiment_settings.trainer_settings.num_nodes, num_gpus_per_node=8, ray_actor_type=pipelines.TrainREINFORCEStrategy)
reward_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=RewardModel)

assert 1 <= experiment_settings.trainer_settings.reward_model_replicas <= 8
reward_model = RayGroup(num_nodes=1, num_gpus_per_node=experiment_settings.trainer_settings.reward_model_replicas, ray_actor_type=RewardModel)
# reference_model = RayGroup(num_nodes=1, num_gpus_per_node=1, ray_actor_type=ReferenceModel)

# TODO_RLOO if possible hide init inside RayGroup
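
A minimal standalone sketch (illustrative, not code from this repository) of the sizing rule the new assert encodes: all reward-model replicas share a single node and each takes one GPU, so num_gpus_per_node equals the replica count and the upper bound is the node's 8 GPUs.

GPUS_PER_NODE = 8  # assumption: one 8-GPU node hosts all reward-model replicas

def reward_group_shape(reward_model_replicas: int) -> dict:
    # Same bound as the assert above, raised as a ValueError for clarity
    if not 1 <= reward_model_replicas <= GPUS_PER_NODE:
        raise ValueError(
            f'reward_model_replicas must be in [1, {GPUS_PER_NODE}], got {reward_model_replicas}'
        )
    # One node, one GPU per replica
    return {'num_nodes': 1, 'num_gpus_per_node': reward_model_replicas}

print(reward_group_shape(4))  # {'num_nodes': 1, 'num_gpus_per_node': 4}
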
7 changes: 5 additions & 2 deletions turbo_alignment/common/logging/clearml.py
@@ -7,9 +7,12 @@

def create_clearml_task(parameters: ClearMLSettings, config: dict[str, Any] | None = None) -> Task:
clearml_task = Task.init(
task_name=parameters.task_name, project_name=parameters.project_name, continue_last_task=True # FIXME?
task_name=parameters.task_name,
project_name=parameters.project_name,
continue_last_task=True,
output_uri=False,
)

clearml_task.connect_configuration(config, name='HyperParameters')

return clearml_task
return clearml_task
4 changes: 2 additions & 2 deletions turbo_alignment/common/tf/special_tokens_setter.py
@@ -77,7 +77,7 @@ def setSEP(self, sep_token: str | None) -> None:
return None

def set_all(self) -> None:
# self.setBOS(bos_token=self._special_tokens_settings.bos_token)
self.setBOS(bos_token=self._special_tokens_settings.bos_token)
self.setEOS(eos_token=self._special_tokens_settings.eos_token)
self.setPAD(pad_token=self._special_tokens_settings.pad_token)
self.setUNK(unk_token=self._special_tokens_settings.unk_token)
@@ -94,7 +94,7 @@ def set_custom_tokens(self, tokens: list[str]) -> None:
assert added_tokens == len(tokens)

def setup_model_config(self, model: PreTrainedModel) -> None:
# model.config.bos_token_id = self._tokenizer.bos_token_id
model.config.bos_token_id = self._tokenizer.bos_token_id
model.config.eos_token_id = self._tokenizer.eos_token_id

if self._tokenizer.pad_token_id is not None:
4 changes: 2 additions & 2 deletions turbo_alignment/dataset/chat/chat.py
@@ -247,8 +247,8 @@ def _truncate_and_merge(
input_ids = input_ids[cut_index:]

labels = labels[-len(input_ids) :]
# input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
# labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))
input_ids = np.concatenate((np.array([self.tokenizer.bos_token_id]), input_ids))
labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))

return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound)

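
The re-enabled lines prepend the BOS token id to the truncated input_ids and a loss-masking label to labels, so the two arrays stay aligned while the BOS position is excluded from the loss. A toy, self-contained illustration follows; the DISABLE_LOSS_LABEL value of -100 and the token ids are assumptions made for the example.

import numpy as np

DISABLE_LOSS_LABEL = -100  # assumed ignore-index value
bos_token_id = 1           # illustrative id

input_ids = np.array([42, 43, 44, 45])                               # after left truncation
labels = np.array([DISABLE_LOSS_LABEL, DISABLE_LOSS_LABEL, 44, 45])  # prompt part masked

input_ids = np.concatenate((np.array([bos_token_id]), input_ids))
labels = np.concatenate((np.array([DISABLE_LOSS_LABEL]), labels))

assert len(input_ids) == len(labels)
print(input_ids)  # -> [1, 42, 43, 44, 45] (BOS prepended)
print(labels)     # -> [-100, -100, -100, 44, 45] (BOS position excluded from loss)
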
8 changes: 6 additions & 2 deletions turbo_alignment/generators/rm.py
@@ -9,16 +9,20 @@
import ray

class RayRMSamplingGenerator(BaseGenerator[SamplingDatasetRecord, RMSamplingInferenceOutput]):
def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, **kwargs):
def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, model_replicas: int = 1, **kwargs):
self._collator = DataCollatorWithPadding(tokenizer=tokenizer)
self._micro_batch = micro_batch
self.model_replicas = model_replicas
super().__init__(tokenizer=tokenizer, **kwargs)

def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEncoding) -> torch.Tensor:
with torch.no_grad():
# TODO assuming that only one reward model
records = {k: v.cuda() for k, v in records.items()}
rewards = ray.get(self._model.async_forward(records))[0]

rank = torch.distributed.get_rank()

rewards = ray.get(self._model.reward_forward(records=records, index=rank % self.model_replicas))
return rewards # .squeeze()

def generate_from_batch(
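
With several reward-model replicas available, each data-parallel training rank now selects its replica round-robin via rank % model_replicas instead of always calling the single replica as before. A standalone sketch of that routing:

def replica_index(rank: int, model_replicas: int) -> int:
    # Round-robin assignment of training ranks to reward-model replicas
    return rank % model_replicas

# e.g. 8 training ranks sharing 2 reward-model replicas
assignment = {rank: replica_index(rank, 2) for rank in range(8)}
print(assignment)  # {0: 0, 1: 1, 2: 0, 3: 1, 4: 0, 5: 1, 6: 0, 7: 1}
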
1 change: 1 addition & 0 deletions turbo_alignment/settings/pipelines/train/reinforce.py
@@ -18,6 +18,7 @@
class REINFORCETrainerSettings(TrainerSettings):

num_nodes: int = 2
reward_model_replicas: int = 1
max_new_tokens: int = 1024
stop_token: str = '<eos>'

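
The new field sits next to num_nodes in the trainer settings that cli/train.py loads via parse_file, so an experiment config would carry it under trainer_settings. In the fragment below only the two field names come from this class; the surrounding structure and values are assumptions.

trainer_settings_fragment = {
    'num_nodes': 2,              # policy-training nodes (existing field)
    'reward_model_replicas': 4,  # new field: 1..8 single-GPU reward-model replicas
}
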
5 changes: 5 additions & 0 deletions turbo_alignment/trainers/online/ray/rayactor_group.py
@@ -104,6 +104,11 @@ def reference_forward(self, records: dict[str, torch.Tensor], temperature, loss_
# TODO assuming one reference model
return self._actor_handlers[0].reference_forward.remote(records, temperature, loss_mask)

def reward_forward(self, records: dict[str, torch.Tensor], index: int):
assert index < len(self._actor_handlers), f'Reward model replicas: {len(self._actor_handlers)}, provided index: {index}'

return self._actor_handlers[index].forward.remote(records)

def async_forward(self, records: dict[str, torch.Tensor]):
return [actor.forward.remote(records) for actor in self._actor_handlers]

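
reward_forward dispatches to exactly one actor chosen by index and returns a single Ray ObjectRef, while the existing async_forward fans out to every actor; callers resolve either with ray.get, as generators/rm.py does above. A standalone contrast of the two dispatch styles using a trivial CPU actor (illustrative, not project code):

import ray

@ray.remote
class Echo:  # stand-in for a reward/reference actor
    def forward(self, x):
        return x

handlers = [Echo.remote() for _ in range(2)]

# async_forward-style: one future per actor
all_refs = [actor.forward.remote('hi') for actor in handlers]
print(ray.get(all_refs))  # ['hi', 'hi']

# reward_forward-style: a single replica chosen by index
index = 1
assert index < len(handlers)
print(ray.get(handlers[index].forward.remote('hi')))  # 'hi'
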
9 changes: 7 additions & 2 deletions turbo_alignment/trainers/online/reinforce.py
@@ -83,6 +83,7 @@ def mean_std_format(values):
@dataclass
class REINFORCETrainingArguments(TrainingArguments):
num_nodes: int = 2
reward_model_replicas: int = 1
max_new_tokens: int = 1024
stop_token: str = '<eos>'

@@ -228,6 +229,9 @@ def __init__(
print("Generations Params:\n" + "\n".join([f"{attr}: {getattr(self.generator_transformers_settings, attr, None)}" for attr, _ in self.generator_transformers_settings.__annotations__.items()]))

start = time.time()

# if num_samples_for_reward_stats == 0 then no normalization is done

self.norm_reward_mean, self.norm_reward_std = self.reward_stats(
model=self.model, dataloader=self.get_train_dataloader()
)
@@ -268,6 +272,7 @@ def _get_rm_generator(self, reward_model: torch.nn.Module | PreTrainedModel) ->
generator = RayRMSamplingGenerator(
model=reward_model,
tokenizer=self.tokenizer,
model_replicas=self.args.reward_model_replicas,
)#TODO this type of critic is created for Reward models with CausalLM head and utilize vllm engines
case CriticType.DISTRIBUTED_VLLM:
generator = ...
@@ -527,8 +532,8 @@ def reward_stats(
self, model: torch.nn.Module | PreTrainedModel, dataloader: torch.utils.data.DataLoader
) -> tuple[float, float]:
if self.args.num_samples_for_reward_stats == 0:
norm_reward_mean = 0
norm_reward_std = 1
norm_reward_mean = 0.0
norm_reward_std = 1.0
else:
all_rewards = []
samples_processed = 0
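
Returning 0.0 and 1.0 when num_samples_for_reward_stats == 0 presumably makes the downstream reward standardization, roughly (reward - mean) / std, a no-op, which matches the comment added in __init__ above. A standalone sketch under that assumption; the epsilon and function name are illustrative:

import torch

def normalize_rewards(rewards: torch.Tensor, mean: float, std: float, eps: float = 1e-8) -> torch.Tensor:
    # Identity when (mean, std) == (0.0, 1.0); standardization otherwise
    return (rewards - mean) / (std + eps)

rewards = torch.tensor([0.5, 1.5, 2.5])
print(normalize_rewards(rewards, mean=0.0, std=1.0))  # effectively unchanged
print(normalize_rewards(rewards, mean=1.5, std=1.0))  # centered: roughly [-1.0, 0.0, 1.0]
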
