From 1d8beb0d7684c51eb3b44849b7231911f9dd1248 Mon Sep 17 00:00:00 2001 From: ummagumma Date: Wed, 28 Aug 2024 17:11:21 +0300 Subject: [PATCH] reinforce --- configs/exp/train/reinforce/reinforce.json | 230 ++++++ turbo_alignment/cli/train.py | 124 ++- turbo_alignment/common/vllm_utils.py | 111 +++ turbo_alignment/dataset/chat/chat.py | 219 +++-- turbo_alignment/dataset/chat/collators.py | 79 ++ turbo_alignment/pipelines/train/reinforce.py | 132 +++ .../settings/pipelines/__init__.py | 3 + .../settings/pipelines/train/reinforce.py | 102 +++ turbo_alignment/trainers/reinforce.py | 752 ++++++++++++++++++ turbo_alignment/vllm_server.py | 267 +++++++ 10 files changed, 1912 insertions(+), 107 deletions(-) create mode 100644 configs/exp/train/reinforce/reinforce.json create mode 100644 turbo_alignment/common/vllm_utils.py create mode 100644 turbo_alignment/dataset/chat/collators.py create mode 100644 turbo_alignment/pipelines/train/reinforce.py create mode 100644 turbo_alignment/settings/pipelines/train/reinforce.py create mode 100644 turbo_alignment/trainers/reinforce.py create mode 100644 turbo_alignment/vllm_server.py diff --git a/configs/exp/train/reinforce/reinforce.json b/configs/exp/train/reinforce/reinforce.json new file mode 100644 index 0000000..dc89434 --- /dev/null +++ b/configs/exp/train/reinforce/reinforce.json @@ -0,0 +1,230 @@ +{ + "train_dataset_settings": { + "sources": [ + { + "name": "train_hh", + "records_path": "/from_s3/datasets/train_chat_filtered.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 1024, + "only_answer_loss": true, + "keep_end": true, + "random_cut": false, + "check_max_len_enforced": true + }, + + "val_dataset_settings": { + "sources": [ + { + "name": "val_hh", + "records_path": "/from_s3/datasets/val_chat_filtered.jsonl", + "sample_rate": 0.2 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 1024, + "only_answer_loss": true, + "keep_end": true, + "random_cut": false, + "check_max_len_enforced": true + }, + + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + "do_sample": true, + "num_return_sequences": 3, + "max_new_tokens": 1024, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 0, + "repetition_penalty": 1.0 + }, + "custom_generation_settings": { + "generation_eos_token": "<|eot_id|>", + "remove_prompt": true, + "skip_special_tokens": true + }, + "dataset_settings": { + "sources": [ + { + "name": "cp_hh", + "records_path": "/from_s3/datasets/val_chat_filtered.jsonl", + "num_samples": 10 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 1024, + "random_cut": false, + "keep_end": true, + "only_answer_loss": true, + "check_max_len_enforced": true + }, + "metric_settings": [ + { + "type": "length", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "self_bleu", + "parameters": { + "need_average": 
[true, false] + } + }, + { + "type": "dist_n", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "diversity", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "perplexity", + "parameters": { + "need_average": [true, false] + } + } + ] + }, + "reward_model_settings": { + "model_path": "/from_s3/models/rm", + "model_type": "llama_reward_value", + "model_kwargs": { + "num_labels": 1, + "attn_implementation": "flash_attention_2" + }, + "transformers_settings": {}, + "check_checkpoint_is_compatible": true + }, + + "value_model_settings": { + "model_path": "/from_s3/models/rm", + "model_type": "seq_cls", + "model_kwargs": { + "num_labels": 1, + "attn_implementation": "flash_attention_2" + }, + "transformers_settings": {}, + "check_checkpoint_is_compatible": true + }, + "model_settings": { + "model_path": "/from_s3/models/policy/", + "model_type": "causal", + "transformers_settings": {}, + "model_kwargs": { + "attn_implementation": "flash_attention_2" + }, + "check_checkpoint_is_compatible": true + }, + + "tokenizer_settings": { + "use_fast": true, + "padding_side": "left" + }, + + "peft_settings": { + "inference_mode": false, + "r": 64, + "lora_alpha": 64, + "lora_dropout": 0.0, + "use_rslora": false, + "init_lora_weights": true, + "task_type": "CAUSAL_LM" + }, + "reinforce_class": "REINFORCETrainerRMRV", + + "trainer_settings": { + "evaluation_strategy": "steps", + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "gradient_accumulation_steps": 10, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_epsilon": 1e-5, + "eval_steps": 50, + "save_steps": 100, + "save_strategy": "steps", + "load_best_model_at_end": false, + "logging_steps": 1, + "learning_rate": 1e-5, + "max_steps": 1001, + "lr_scheduler_type": "constant_with_warmup", + "warmup_ratio": 0.03, + "fp16": false, + "bf16": true, + "optim": "adamw_torch", + "weight_decay": 0.0, + "max_grad_norm": 2, + "save_total_limit": 11, + "dataloader_num_workers": 12, + "deepspeed": "configs/exp/deepspeed/ds_config_stage_2_reinforce.json", + + "max_response_length": 1024, + "stop_token": "<|eot_id|>", + "temperature": 0.7, + "non_eos_penalty": true, + "penalty_reward_value": -1, + "clip_rewards_min": -1e8, + "clip_rewards_max": 1e8, + "whiten_rewards": false, + "kl_coef": 0.05, + "risk_coef": 0.00, + "mean_baseline_coef": 0.95, + + "init_kl_coef": null, + "target_kl": null, + "adaptive_kl_k": null, + "adaptive_kl_clip_value": null, + "min_kl_coef": null, + "max_kl_coef": null, + + "num_samples_for_reward_stats": 1000 + + }, + "wandb_settings": { + "project_name": "...", + "group": "hh-llama3-8b", + "job_type": "reinforce", + "run_name": "rvm-kl_0.05", + "entity": "...", + "tags": ["reinforce", "llama3-8b", "hh", "rvm"] + }, + "log_path": "train_output/hh/llama3-8b/reinforce/rvm-kl_0.05-0.01/" +} + diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py index da109ff..034d0d6 100755 --- a/turbo_alignment/cli/train.py +++ b/turbo_alignment/cli/train.py @@ -7,99 +7,147 @@ from turbo_alignment.settings import pipelines as pipeline_settings -@app.command(name='train_sft', help='Run PEFT pipeline') +@app.command(name="train_sft", help="Run PEFT pipeline") def train_sft_entrypoint( experiment_settings_path: Path = typer.Option( ..., - '--experiment_settings_path', + "--experiment_settings_path", exists=True, - help='Path to experiment config file', - ) + help="Path to experiment config file", + ), ) -> None: - experiment_settings = 
pipeline_settings.SftTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = pipeline_settings.SftTrainExperimentSettings.parse_file( + experiment_settings_path + ) pipelines.TrainSFTStrategy().run(experiment_settings) -@app.command(name='train_dpo', help='Run DPO pipeline') +@app.command(name="train_dpo", help="Run DPO pipeline") def train_dpo_entrypoint( experiment_settings_path: Path = typer.Option( ..., - '--experiment_settings_path', + "--experiment_settings_path", exists=True, - help='Path to experiment config file', - ) + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.DPOTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = pipeline_settings.DPOTrainExperimentSettings.parse_file( + experiment_settings_path + ) pipelines.TrainDPOStrategy().run(experiment_settings) -@app.command(name='train_kto', help='Run KTO pipeline') +@app.command(name="train_kto", help="Run KTO pipeline") def train_kto_entrypoint( experiment_settings_path: Path = typer.Option( ..., - '--experiment_settings_path', + "--experiment_settings_path", exists=True, - help='Path to experiment config file', - ) + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.KTOTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = pipeline_settings.KTOTrainExperimentSettings.parse_file( + experiment_settings_path + ) pipelines.TrainKTOStrategy().run(experiment_settings) -@app.command(name='train_ddpo', help='Run DDPO pipeline') +@app.command(name="train_ddpo", help="Run DDPO pipeline") def train_ddpo_entrypoint( experiment_settings_path: Path = typer.Option( ..., - '--experiment_settings_path', + "--experiment_settings_path", exists=True, - help='Path to experiment config file', - ) + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.DDPOTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = pipeline_settings.DDPOTrainExperimentSettings.parse_file( + experiment_settings_path + ) pipelines.TrainDDPOStrategy().run(experiment_settings) -@app.command(name='train_rm', help='Run RM pipeline') +@app.command(name="train_rm", help="Run RM pipeline") def train_rm_entrypoint( experiment_settings_path: Path = typer.Option( ..., - '--experiment_settings_path', + "--experiment_settings_path", exists=True, - help='Path to experiment config file', - ) + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.RMTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = pipeline_settings.RMTrainExperimentSettings.parse_file( + experiment_settings_path + ) pipelines.TrainRMStrategy().run(experiment_settings) -@app.command(name='train_classification', help='Train Classifier') +@app.command(name="train_classification", help="Train Classifier") def classification_training( experiment_settings_path: Path = typer.Option( ..., - '--experiment_settings_path', + "--experiment_settings_path", exists=True, - help='Path to experiment config file', - ) + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.ClassificationTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = ( + pipeline_settings.ClassificationTrainExperimentSettings.parse_file( + experiment_settings_path + ) + ) pipelines.TrainClassificationStrategy().run(experiment_settings) 
-@app.command(name='train_multimodal', help='Train Multimodal') +@app.command(name="train_multimodal", help="Train Multimodal") def multimodal_training( experiment_settings_path: Path = typer.Option( - ..., '--experiment_settings_path', exists=True, help='Path to experiment config file' - ) + ..., + "--experiment_settings_path", + exists=True, + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.MultimodalTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = ( + pipeline_settings.MultimodalTrainExperimentSettings.parse_file( + experiment_settings_path + ) + ) pipelines.TrainMultimodalStrategy().run(experiment_settings) -@app.command(name='train_rag', help='Train RAG') +@app.command(name="train_rag", help="Train RAG") def rag_training( experiment_settings_path: Path = typer.Option( - ..., '--experiment_settings_path', exists=True, help='Path to experiment config file' - ) + ..., + "--experiment_settings_path", + exists=True, + help="Path to experiment config file", + ), ) -> None: - experiment_settings = pipeline_settings.RAGTrainExperimentSettings.parse_file(experiment_settings_path) + experiment_settings = pipeline_settings.RAGTrainExperimentSettings.parse_file( + experiment_settings_path + ) pipelines.TrainRAGStrategy().run(experiment_settings) + + +@app.command(name="train_reinforce", help="Train REINFORCE pipeline") +def reinforce_training( + experiment_settings_path: Path = typer.Option( + ..., + "--experiment_settings_path", + exists=True, + help="Path to experiment config file", + ), + num_servers: int = typer.Option( + ..., + "--num_servers", + exists=True, + help="Number of gpus dedicated to server", + ), +) -> None: + experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file( + experiment_settings_path + ) + experiment_settings.trainer_settings.num_servers = num_servers + pipelines.TrainREINFORCEStrategy().run(experiment_settings) diff --git a/turbo_alignment/common/vllm_utils.py b/turbo_alignment/common/vllm_utils.py new file mode 100644 index 0000000..8546d22 --- /dev/null +++ b/turbo_alignment/common/vllm_utils.py @@ -0,0 +1,111 @@ +from typing import List, Tuple + +import torch +import vllm +from transformers import AutoTokenizer + + +def pad_list( + tokenizer: AutoTokenizer, + tensor_list: List[torch.Tensor], + pad_side: str, + pad_value: int, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + old_padding_side = tokenizer.padding_side + old_pad_value = tokenizer.pad_token_id + tokenizer.padding_side = pad_side + tokenizer.pad_token_id = pad_value + + padout = tokenizer.pad( + {"input_ids": tensor_list}, + return_tensors="pt", + return_attention_mask=True, + max_length=max_length, + padding="max_length", + ) + + tokenizer.padding_side = old_padding_side + tokenizer.pad_token_id = old_pad_value + + return padout["input_ids"], padout["attention_mask"] + + +def vllm_generations_postprocess( + tokenizer: AutoTokenizer, + generations: List[vllm.outputs.RequestOutput], + max_length: int, + pad_to_max_len: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Queries+responses will go the reward model. Should be padded from the left. Output from the last token. + Queries+responses will go the reference model. Padding doesn't matter. Output from the response tokens. 
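+
+    Returns a tuple of (query_responses, attention_mask, response_tokens_mask,
+    position_ids), each left-padded to `max_length`.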
+ """ + + # Merge queries and responses to the same sequences + # Create a mask showing where are the response tokens + query_responses = [] + response_tokens_mask = [] + # logprobs = [] + for g in generations: + for i in range(len(g.outputs)): + prompt_tokens = torch.tensor(g.prompt_token_ids) + response_tokens = torch.tensor(g.outputs[i].token_ids)[ + : max_length - len(prompt_tokens) + ] + qr = torch.cat( + [ + prompt_tokens, + response_tokens, + ], + ) + assert len(qr.shape) == 1, qr.shape + query_responses.append(qr) + + assert len(response_tokens) != 0, ( + len(response_tokens), + response_tokens, + g, + ) + assert len(prompt_tokens) != 0, ( + len(prompt_tokens), + prompt_tokens, + g, + ) + + r_attn = torch.cat( + [ + torch.zeros(len(prompt_tokens)), + torch.ones(len(response_tokens)), + ], + ) + + assert len(r_attn.shape) == 1, r_attn.shape + response_tokens_mask.append(r_attn) + + # Query-responses are padded to the same length + query_responses, attention_mask = pad_list( + tokenizer, + query_responses, + "left", + tokenizer.pad_token_id, + max_length=max_length, + ) + # as well as the response masks + response_tokens_mask, _ = pad_list( + tokenizer, + response_tokens_mask, + "left", + 0, + max_length=max_length, + ) + + position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0) + position_ids.masked_fill_(attention_mask.to(torch.bool) == 0, 0) + + return ( + query_responses, + attention_mask.to(torch.int32), + response_tokens_mask.to(torch.int32), + position_ids.to(torch.int32), + ) diff --git a/turbo_alignment/dataset/chat/chat.py b/turbo_alignment/dataset/chat/chat.py index a1409ba..92692fe 100755 --- a/turbo_alignment/dataset/chat/chat.py +++ b/turbo_alignment/dataset/chat/chat.py @@ -7,24 +7,16 @@ import numpy as np import numpy.typing as npt import torch +from src.common.data.io import read_jsonl +from src.common.logging import get_project_logger +from src.constants import DISABLE_LOSS_LABEL +from src.dataset.base import AlignmentDataset +from src.dataset.chat.models import ChatDatasetRecord, ChatMessage, ChatMessageRole +from src.dataset.registry import ChatDatasetTypeRegistry +from src.settings.datasets.base import DatasetSourceSettings, DatasetStrategy +from src.settings.datasets.chat import ChatDatasetSettings from transformers import PreTrainedTokenizerBase -from turbo_alignment.common.data.io import read_jsonl -from turbo_alignment.common.logging import get_project_logger -from turbo_alignment.constants import DISABLE_LOSS_LABEL -from turbo_alignment.dataset.base import AlignmentDataset -from turbo_alignment.dataset.chat.models import ( - ChatDatasetRecord, - ChatMessage, - ChatMessageRole, -) -from turbo_alignment.dataset.registry import ChatDatasetTypeRegistry -from turbo_alignment.settings.datasets.base import ( - DatasetSourceSettings, - DatasetStrategy, -) -from turbo_alignment.settings.datasets.chat import ChatDatasetSettings - from .conversation import Conversation logger = get_project_logger() @@ -45,12 +37,10 @@ def __init__( self._read() @overload - def __tokenize(self, inputs: str) -> npt.NDArray[np.int64]: - ... + def __tokenize(self, inputs: str) -> npt.NDArray[np.int64]: ... @overload - def __tokenize(self, inputs: list[str]) -> np.ndarray: - ... + def __tokenize(self, inputs: list[str]) -> np.ndarray: ... 
def __tokenize(self, inputs): return self.tokenizer( @@ -58,8 +48,8 @@ def __tokenize(self, inputs): add_special_tokens=False, padding=False, truncation=False, - return_tensors='np', - )['input_ids'] + return_tensors="np", + )["input_ids"] def __keep_end( self, @@ -85,7 +75,7 @@ def __keep_end( left_bound = i - raise ValueError('Can\'t trim dialogue to fit all requirements') + raise ValueError("Can't trim dialogue to fit all requirements") def __keep_start( self, @@ -94,7 +84,9 @@ def __keep_start( inference: bool, max_tokens: int | None = None, ) -> tuple[int, int]: - for i, (message, end_index) in enumerate(zip(conversation.messages[::-1], replicas_cum_len[::-1])): + for i, (message, end_index) in enumerate( + zip(conversation.messages[::-1], replicas_cum_len[::-1]) + ): if self.settings.only_answer_loss: if inference and message.role == ChatMessageRole.BOT: continue @@ -104,7 +96,7 @@ def __keep_start( if max_tokens is None or end_index < max_tokens: return 0, len(replicas_cum_len) - i - raise ValueError('Can\'t trim dialogue to fit all requirements') + raise ValueError("Can't trim dialogue to fit all requirements") def __truncate( self, @@ -113,15 +105,6 @@ def __truncate( inference: bool, max_tokens: int | None, ) -> tuple[int, int]: - ''' - truncate dialogue to fit all requirements: - - cumulative tokens num is less than max_tokens - - keep dialogue end or start according to settings - - remove bot's messages from dialog end at inference - - remove user's messages from dialog end at non inference (training) - - returns [first_message_index, last_message_index] - ''' if self.settings.keep_end: return self.__keep_end( conversation=conversation, @@ -156,7 +139,8 @@ def _truncate_and_merge( bot_prefix_tokens = role_prefix_tokens[ChatMessageRole.BOT] max_tokens = ( - self.settings.max_tokens_count - (len(bot_prefix_tokens) if inference else 1) + self.settings.max_tokens_count + - (len(bot_prefix_tokens) if inference else 1) if self.settings.max_tokens_count is not None else None ) @@ -174,16 +158,20 @@ def _truncate_and_merge( ) if not inference and right_bound - left_bound < 2: - raise ValueError('Less than two messages left after truncation') + raise ValueError("Less than two messages left after truncation") if ( inference and left_bound == 0 and right_bound == 1 # если при инференсе остался только системный промпт and conversation.messages[0].role == ChatMessageRole.SYSTEM ): - raise ValueError('Less than two messages left after truncation') - if not inference and self._all_loss_disabled(conversation.messages[left_bound:right_bound]): - raise ValueError('No messages with enabled loss were left after truncations') + raise ValueError("Less than two messages left after truncation") + if not inference and self._all_loss_disabled( + conversation.messages[left_bound:right_bound] + ): + raise ValueError( + "No messages with enabled loss were left after truncations" + ) if random_cut: bot_indices = [ @@ -195,27 +183,64 @@ def _truncate_and_merge( input_ids = np.array([self.tokenizer.bos_token_id]) labels = np.array([DISABLE_LOSS_LABEL]) + response_mask = np.array([0]) truncated_conversation_messages = conversation.messages[left_bound:right_bound] truncated_tokenized_replicas = tokenized_replicas[left_bound:right_bound] - if self.source.system_prompt is not None and self.settings.keep_end and left_bound != 0: - truncated_conversation_messages = [conversation.messages[0]] + truncated_conversation_messages - truncated_tokenized_replicas = [truncated_tokenized_replicas[0]] + 
truncated_tokenized_replicas + if ( + self.source.system_prompt is not None + and self.settings.keep_end + and left_bound != 0 + ): + truncated_conversation_messages = [ + conversation.messages[0] + ] + truncated_conversation_messages + truncated_tokenized_replicas = [ + truncated_tokenized_replicas[0] + ] + truncated_tokenized_replicas for ind, (message, tokenized_replica) in enumerate( zip( - truncated_conversation_messages[left_bound:right_bound], - truncated_tokenized_replicas[left_bound:right_bound], + truncated_conversation_messages, + truncated_tokenized_replicas, ) ): prefix_tokens = role_prefix_tokens[message.role] - merged_replica = np.concatenate((prefix_tokens, tokenized_replica, suffix_tokens)) + extra_tokens_len = ( + 1 + + len(prefix_tokens) + + len(suffix_tokens) + + (len(bot_prefix_tokens) if inference else 1) + ) + # If the last user's message is longer than max_tokens_count - + # slice the start of this message + # (assuming that the task is stated at the begging of the message. + # Seems to be true after a glance on several such messages) + if ( + False + and len(truncated_tokenized_replicas) == 1 + and len(tokenized_replica) + extra_tokens_len + > self.settings.max_tokens_count + ): + tokenized_replica = tokenized_replica[ + : self.settings.max_tokens_count - extra_tokens_len + ] + + merged_replica = np.concatenate( + (prefix_tokens, tokenized_replica, suffix_tokens) + ) input_ids = np.concatenate((input_ids, merged_replica)) if ( - (self.settings.only_last_replica_loss and ind != right_bound - left_bound - 1) - or (self.settings.only_answer_loss and message.role != ChatMessageRole.BOT) + ( + self.settings.only_last_replica_loss + and ind != right_bound - left_bound - 1 + ) + or ( + self.settings.only_answer_loss + and message.role != ChatMessageRole.BOT + ) or message.disable_loss ): replica_labels = np.full(merged_replica.shape, DISABLE_LOSS_LABEL) @@ -228,15 +253,50 @@ def _truncate_and_merge( ) ) labels = np.concatenate((labels, replica_labels)) + if ind == len(truncated_tokenized_replicas) - 1: + response_mask = np.concatenate( + [ + response_mask, + np.zeros_like(prefix_tokens), + np.ones_like(tokenized_replica), + np.ones_like(suffix_tokens), + ] + ) + else: + response_mask = np.concatenate( + [response_mask, np.zeros_like(merged_replica)] + ) if inference: input_ids = np.concatenate((input_ids, bot_prefix_tokens)) - labels = np.concatenate((labels, np.full(bot_prefix_tokens.shape, DISABLE_LOSS_LABEL))) + labels = np.concatenate( + (labels, np.full(bot_prefix_tokens.shape, DISABLE_LOSS_LABEL)) + ) + response_mask = np.concatenate( + [response_mask, np.zeros_like(bot_prefix_tokens)] + ) else: input_ids = np.append(input_ids, self.tokenizer.eos_token_id) labels = np.append(labels, DISABLE_LOSS_LABEL) + response_mask = np.append(response_mask, 0) + + assert len(input_ids) == len(labels) and len(labels) == len(response_mask), ( + len(input_ids), + len(labels), + len(response_mask), + ) + if self.settings.check_max_len_enforced: + assert len(input_ids) <= self.settings.max_tokens_count, ( + len(input_ids), + self.settings.max_tokens_count, + ) - return input_ids, labels, conversation.get_prompt_repr(left_bound, right_bound) + return ( + input_ids, + labels, + response_mask, + conversation.get_prompt_repr(left_bound, right_bound), + ) def _encode( self, @@ -244,15 +304,22 @@ def _encode( inference: bool, random_cut: bool, ) -> list[dict[str, Any] | None]: - conversations = [Conversation(system_prompt=self.source.system_prompt, messages=r.messages) for r in records] + 
conversations = [ + Conversation(system_prompt=self.source.system_prompt, messages=r.messages) + for r in records + ] - logger.info(f'Tokenizing dataset {self.source.name}') - tokenized_replicas = self.__tokenize([m.content for c in conversations for m in c.messages]) + logger.info(f"Tokenizing dataset {self.source.name}") + tokenized_replicas = self.__tokenize( + [m.content for c in conversations for m in c.messages] + ) tokenized_conversations = [] offset = 0 for c in conversations: - tokenized_conversations.append([tokenized_replicas[offset + i] for i in range(len(c.messages))]) + tokenized_conversations.append( + [tokenized_replicas[offset + i] for i in range(len(c.messages))] + ) offset += len(c.messages) role_prefix_tokens = { @@ -264,13 +331,17 @@ def _encode( for role in ChatMessageRole } # TODO: what if suffix is empty? - suffix_tokens = self.__tokenize(self.settings.prompt_template.suffix_template)[0] + suffix_tokens = self.__tokenize(self.settings.prompt_template.suffix_template)[ + 0 + ] - logger.info(f'Postprocessing tokenized data in {self.source.name}') + logger.info(f"Postprocessing tokenized data in {self.source.name}") output: list[dict[str, Any] | None] = [] - for record, conversation, tokenized_replicas in zip(records, conversations, tokenized_conversations): + for record, conversation, tokenized_replicas in zip( + records, conversations, tokenized_conversations + ): try: - input_ids, labels, prompt = self._truncate_and_merge( + input_ids, labels, response_mask, prompt = self._truncate_and_merge( conversation=conversation, tokenized_replicas=tokenized_replicas, role_prefix_tokens=role_prefix_tokens, @@ -281,18 +352,24 @@ def _encode( except ValueError as ex: output.append(None) - logger.warning(f'Sample dropped: {ex}') + logger.warning(f"Sample dropped: {ex}") continue encoded_record: dict[str, Any] = { # 'id': record.id, FIXME: dont work with collators - 'input_ids': torch.LongTensor(input_ids), - 'labels': torch.LongTensor(labels), - 'attention_mask': torch.ones(input_ids.shape, dtype=torch.int64), + "input_ids": torch.LongTensor(input_ids), + "labels": torch.LongTensor(labels), + "response_mask": torch.LongTensor(response_mask), + "attention_mask": torch.ones(input_ids.shape, dtype=torch.int64), } if inference: encoded_record.update( - {'prompt': prompt, 'id': record.id, 'messages': record.messages, 'meta': record.meta} + { + "prompt": prompt, + "id": record.id, + "messages": record.messages, + "meta": record.meta, + } ) output.append(encoded_record) @@ -301,13 +378,11 @@ def _encode( @staticmethod @overload - def _read_records(records: Path) -> list[ChatDatasetRecord]: - ... + def _read_records(records: Path) -> list[ChatDatasetRecord]: ... @staticmethod @overload - def _read_records(records: list[dict]) -> list[ChatDatasetRecord]: - ... + def _read_records(records: list[dict]) -> list[ChatDatasetRecord]: ... 
@staticmethod def _read_records(records) -> list[ChatDatasetRecord]: @@ -320,7 +395,9 @@ def _read_records(records) -> list[ChatDatasetRecord]: @ChatDatasetTypeRegistry.register(DatasetStrategy.TRAIN) class TrainChatDataset(ChatDataset): - def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]: + def convert_records( + self, records: list[ChatDatasetRecord] + ) -> list[dict[str, Any] | None]: return self._encode(records, inference=False, random_cut=False) @@ -336,7 +413,11 @@ def __init__( ) -> None: self._random_cut = random_cut - super().__init__(source=source, settings=settings, tokenizer=tokenizer, read=read) + super().__init__( + source=source, settings=settings, tokenizer=tokenizer, read=read + ) - def convert_records(self, records: list[ChatDatasetRecord]) -> list[dict[str, Any] | None]: + def convert_records( + self, records: list[ChatDatasetRecord] + ) -> list[dict[str, Any] | None]: return self._encode(records, inference=True, random_cut=self._random_cut) diff --git a/turbo_alignment/dataset/chat/collators.py b/turbo_alignment/dataset/chat/collators.py new file mode 100644 index 0000000..4ef72c5 --- /dev/null +++ b/turbo_alignment/dataset/chat/collators.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass +from typing import Any + +import torch +import transformers +from src.constants import DISABLE_LOSS_LABEL +from transformers import PreTrainedTokenizerBase + + +@dataclass +class REINFORCEDataCollator: + tokenizer: PreTrainedTokenizerBase + pad_to_multiple_of: int | None = None + + def _get_batch( + self, + examples: list[dict[str, dict[str, Any]]], + tokenizer: PreTrainedTokenizerBase, + max_length: int, + ) -> transformers.BatchEncoding: + features = [ex for ex in examples] + labels = [ + v.tolist() + for feature in features + for k, v in feature.items() + if k == "labels" + ] + no_labels_features = [ + { + k: v + for k, v in feature.items() + if k + not in ["prompt", "id", "messages", "meta", "labels", "response_mask"] + } + for feature in features + ] + + # print('🤫'*5) + # print(no_labels_features) + + batch = tokenizer.pad( + no_labels_features, + padding="max_length", + max_length=max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + if "response_mask" in examples[0].keys(): + padded_response_masks = [] + for feat in features: + x = feat["response_mask"] + padded_response_mask = torch.hstack( + [ + torch.zeros(max_length - len(x)), + x, + ] + ) + padded_response_masks.append(padded_response_mask) + + batch["response_mask"] = torch.vstack(padded_response_masks).to(x.dtype) + + batch["labels"] = torch.tensor( + [ + label + (max_length - len(label)) * [DISABLE_LOSS_LABEL] + for label in labels + ] + ) + + # print('👀'*5) + # print(batch) + return batch + + def __call__(self, examples: list[dict[str, dict[str, Any]]]) -> dict[str, Any]: + max_length = 0 + for ex in examples: + max_length = max(max_length, len(ex["input_ids"])) + + return dict(self._get_batch(examples, self.tokenizer, max_length)) diff --git a/turbo_alignment/pipelines/train/reinforce.py b/turbo_alignment/pipelines/train/reinforce.py new file mode 100644 index 0000000..68fcfba --- /dev/null +++ b/turbo_alignment/pipelines/train/reinforce.py @@ -0,0 +1,132 @@ +from typing import Callable + +from torch.utils.data import ConcatDataset, Dataset +from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.data.data_collator import DataCollatorMixin + +import turbo_alignment.trainers.reinforce as reinforce_trainers 
+from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback +from turbo_alignment.common.logging import get_project_logger +from turbo_alignment.constants import TRAINER_LOGS_FOLDER +from turbo_alignment.dataset.chat.chat import InferenceChatDataset +from turbo_alignment.dataset.chat.collators import REINFORCEDataCollator +from turbo_alignment.dataset.loader import DatasetLoader +from turbo_alignment.metrics.metric import Metric +from turbo_alignment.metrics.registry import MetricSettingsRegistry +from turbo_alignment.pipelines.train.base import BaseTrainStrategy +from turbo_alignment.settings.datasets.base import DatasetStrategy +from turbo_alignment.settings.pipelines.train.reinforce import ( + REINFORCETrainerSettings, + REINFORCETrainExperimentSettings, +) +from turbo_alignment.trainers.reinforce import REINFORCETrainingArguments + +logger = get_project_logger() + + +class TrainREINFORCEStrategy(BaseTrainStrategy[REINFORCETrainExperimentSettings]): + @staticmethod + def _get_data_collator( + experiment_settings: REINFORCETrainExperimentSettings, + tokenizer: PreTrainedTokenizerBase, + **kwargs, + ) -> Callable: + return REINFORCEDataCollator(tokenizer) + + @staticmethod + def _get_cherry_pick_callback( + experiment_settings: REINFORCETrainExperimentSettings, + tokenizer: PreTrainedTokenizerBase, + **kwargs, + ) -> ChatCherryPickCallback: + cherry_pick_settings = experiment_settings.cherry_pick_settings + + cherry_pick_datasets = DatasetLoader[InferenceChatDataset]( + InferenceChatDataset + ).load_datasets( + cherry_pick_settings.dataset_settings, + tokenizer=tokenizer, + strategy=DatasetStrategy.INFERENCE, + ) + + metrics = [ + Metric.by_name(metric.type)( + MetricSettingsRegistry.by_name(metric.type)(**metric.parameters) + ) + for metric in cherry_pick_settings.metric_settings + ] + + return ChatCherryPickCallback( + cherry_pick_settings=cherry_pick_settings, + datasets=cherry_pick_datasets, + metrics=metrics, + ) + + @staticmethod + def _get_training_args( + experiment_settings: REINFORCETrainExperimentSettings, + ) -> REINFORCETrainerSettings: + return REINFORCETrainingArguments( + output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), + label_names=[], + remove_unused_columns=False, + report_to=[], + **experiment_settings.trainer_settings.dict(), + ) + + @staticmethod + def _get_trainer( + training_args: REINFORCETrainerSettings, + experiment_settings: REINFORCETrainExperimentSettings, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + train_dataset: Dataset, + val_dataset: Dataset, + data_collator: Callable, + ): + cls = getattr(reinforce_trainers, experiment_settings.reinforce_class) + + return cls( + args=training_args, + tokenizer=tokenizer, + policy=model, + train_dataset=train_dataset, + eval_dataset=val_dataset, + data_collator=data_collator, + callbacks=[], + peft_config=experiment_settings.peft_settings.dict(), + policy_model_dir=experiment_settings.model_settings.model_path, + max_tokens_count=experiment_settings.train_dataset_settings.max_tokens_count, + ) + + def _dataset_and_collator_sanity_check( + self, dataset: Dataset, collator: DataCollatorMixin + ) -> None: + logger.info(f"Train sample input_ids:\n{dataset[0]}") + logger.info( + f'Train sample example:\n{self.tokenizer.decode(dataset[0]["input_ids"])}' + ) + + def _get_datasets( + self, experiment_settings: REINFORCETrainExperimentSettings + ) -> tuple[Dataset, Dataset]: + train_dataset: ConcatDataset = ConcatDataset( + datasets=DatasetLoader[InferenceChatDataset]( + 
InferenceChatDataset + ).load_datasets( + experiment_settings.train_dataset_settings, + tokenizer=self.tokenizer, + strategy=DatasetStrategy.INFERENCE, + ) + ) + + val_dataset: ConcatDataset = ConcatDataset( + datasets=DatasetLoader[InferenceChatDataset]( + InferenceChatDataset + ).load_datasets( + experiment_settings.val_dataset_settings, + tokenizer=self.tokenizer, + strategy=DatasetStrategy.INFERENCE, + ) + ) + return train_dataset, val_dataset diff --git a/turbo_alignment/settings/pipelines/__init__.py b/turbo_alignment/settings/pipelines/__init__.py index 1b6143a..0a4d7c0 100755 --- a/turbo_alignment/settings/pipelines/__init__.py +++ b/turbo_alignment/settings/pipelines/__init__.py @@ -14,5 +14,8 @@ MultimodalTrainExperimentSettings, ) from turbo_alignment.settings.pipelines.train.rag import RAGTrainExperimentSettings +from turbo_alignment.settings.pipelines.train.reinforce import ( + REINFORCETrainExperimentSettings, +) from turbo_alignment.settings.pipelines.train.rm import RMTrainExperimentSettings from turbo_alignment.settings.pipelines.train.sft import SftTrainExperimentSettings diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py new file mode 100644 index 0000000..65b0003 --- /dev/null +++ b/turbo_alignment/settings/pipelines/train/reinforce.py @@ -0,0 +1,102 @@ +from typing import Literal, Optional + +from peft import TaskType + +from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel +from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings +from turbo_alignment.settings.datasets import ChatMultiDatasetSettings +from turbo_alignment.settings.model import ( + ModelForPeftSettings, + PreTrainedAdaptersModelSettings, + PreTrainedModelSettings, +) +from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings +from turbo_alignment.settings.tf.trainer import TrainerSettings + +# @dataclass +# class OnpolicyRuntimeConfig: +# # various batch sizes +# world_size: Optional[int] = 2 +# """The number of processes (GPUs) to use""" +# num_updates: Optional[int] = None +# """The number of updates to train""" +# micro_batch_size: Optional[int] = None +# """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" +# local_batch_size: Optional[int] = None +# """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)""" +# batch_size: Optional[int] = None +# """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)""" +# local_mini_batch_size: Optional[int] = None +# """the mini batch size per GPU""" +# mini_batch_size: Optional[int] = None +# """the mini batch size across GPUs""" + + +class REINFORCETrainerSettings(TrainerSettings): + max_response_length: Optional[int] = None + stop_token: Optional[str] = None + temperature: Optional[float] = None + penalty_reward_value: Optional[float] = None + clip_rewards_min: Optional[int] = None + clip_rewards_max: Optional[int] = None + """the reward value for responses that do not contain `stop_token_id`""" + non_eos_penalty: Optional[bool] = None + """whether to penalize responses that do not contain `stop_token_id`""" + + whiten_rewards: Optional[bool] = None + """whether to whiten the rewards""" + kl_coef: Optional[float] = None + """the KL coefficient""" + risk_coef: Optional[float] = None + mean_baseline_coef: Optional[float] = None + + # Adaptive KL rules + init_kl_coef: Optional[float] = None + 
target_kl: Optional[float] = None + adaptive_kl_k: Optional[float] = None + adaptive_kl_clip_value: Optional[float] = None + min_kl_coef: Optional[float] = None + max_kl_coef: Optional[float] = None + + num_generations: int = 1 + num_samples_for_reward_stats: int = 0 + + num_servers: Optional[int] = None + + +class PEFTSettings(ExtraFieldsNotAllowedBaseModel): + inference_mode: bool + r: int + lora_alpha: int + lora_dropout: float + use_rslora: bool + init_lora_weights: ( + bool | Literal["gaussian", "pissa", "pissa_niter_[number of iters]", "loftq"] + ) + task_type: TaskType + + +class REINFORCETrainExperimentSettings(BaseTrainExperimentSettings): + train_dataset_settings: ChatMultiDatasetSettings + val_dataset_settings: ChatMultiDatasetSettings + + cherry_pick_settings: ChatCherryPickSettings + + reward_model_settings: PreTrainedModelSettings | PreTrainedAdaptersModelSettings + value_model_settings: ( + ModelForPeftSettings | PreTrainedModelSettings | PreTrainedAdaptersModelSettings + ) + + trainer_settings: REINFORCETrainerSettings + peft_settings: PEFTSettings + + reinforce_class: Literal[ + "REINFORCETrainerVanillaRM", + "REINFORCETrainerCategoricalPRM", + "REINFORCETrainerCategoricalPRMVariancePenalty", + "REINFORCETrainerCategoricalPRMEntropyPenalty", + "REINFORCETrainerRMRV", + "REINFORCETrainerRMRVNoEMA", + "REINFORCETrainerRMRVNoValues", + "RLOOTrainer", + ] diff --git a/turbo_alignment/trainers/reinforce.py b/turbo_alignment/trainers/reinforce.py new file mode 100644 index 0000000..3c09e4b --- /dev/null +++ b/turbo_alignment/trainers/reinforce.py @@ -0,0 +1,752 @@ +import os +import pickle +import shutil +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import requests +import torch +import torch.nn.functional as F +import vllm +from datasets import Dataset +from peft import LoraConfig, get_peft_model +from torch import nn +from transformers import ( + PreTrainedModel, + PreTrainedTokenizer, + TrainerCallback, + TrainingArguments, +) + +from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer + +INVALID_LOGPROB = 1.0 + + +@dataclass +class REINFORCETrainingArguments(TrainingArguments): + max_response_length: Optional[int] = None + stop_token: Optional[str] = None + temperature: Optional[float] = None + penalty_reward_value: Optional[float] = None + clip_rewards_min: Optional[float] = None + clip_rewards_max: Optional[float] = None + """the reward value for responses that do not contain `stop_token_id`""" + non_eos_penalty: Optional[bool] = None + """whether to penalize responses that do not contain `stop_token_id`""" + + whiten_rewards: Optional[bool] = None + """whether to whiten the rewards""" + kl_coef: Optional[float] = None + """the KL coefficient""" + risk_coef: Optional[float] = None + mean_baseline_coef: Optional[float] = None + + # Adaptive KL rules + init_kl_coef: Optional[float] = None + target_kl: Optional[float] = None + adaptive_kl_k: Optional[float] = None + adaptive_kl_clip_value: Optional[float] = None + min_kl_coef: Optional[float] = None + max_kl_coef: Optional[float] = None + + num_generations: int = 1 + num_samples_for_reward_stats: int = 0 + + num_servers: Optional[int] = None + + +def get_global_mean(values: torch.Tensor): + # Calculate the mean reward for the current process + local_sum = values.sum().item() + num_rewards = torch.tensor(len(values), device=values.device) + + # Create a tensor to hold the global 
mean reward + global_sum = torch.tensor(local_sum, device=values.device) + + # Collect mean rewards from all processes + torch.distributed.all_reduce(global_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(num_rewards, op=torch.distributed.ReduceOp.SUM) + global_mean = (global_sum / num_rewards).item() + + return global_mean + + +def get_global_std(values: torch.Tensor, mean: float): + # Calculate the mean reward for the current process + local_sum = ((values - mean) ** 2).sum().item() + num_rewards = torch.tensor(len(values), device=values.device) + + # Create a tensor to hold the global mean reward + global_sum = torch.tensor(local_sum, device=values.device) + + # Collect mean rewards from all processes + torch.distributed.all_reduce(global_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(num_rewards, op=torch.distributed.ReduceOp.SUM) + global_mean = np.sqrt((global_sum / (num_rewards - 1)).item()) + + return global_mean + + +class AdaptiveKLController: + # https://arxiv.org/pdf/1909.08593 + def __init__( + self, + init_kl_coef: float, + target_kl: float, + k: float, + clip_value: float, + min_kl_coef: float, + max_kl_coef: float, + ): + self.value = init_kl_coef + self.target_kl = target_kl + self.k = k + self.clip_value = clip_value + self.min_kl_coef = min_kl_coef + self.max_kl_coef = max_kl_coef + + def update(self, current): + proportional_error = np.clip( + current.mean().cpu().item() / self.target_kl - 1, + -self.clip_value, + self.clip_value, + ) + mult = 1 + proportional_error * self.k + self.value = np.clip(self.value * mult, self.min_kl_coef, self.max_kl_coef) + + +class FixedKLController: + def __init__(self, value: float): + self.value = value + + def update(self, current): + pass + + +class PI_KLController: + def __init__(self, init_kl_coef: float, target_kl: float, kp: float, ki: float): + self.init_kl_coef = init_kl_coef + self.target_kl = target_kl + self.kp = kp + self.ki = ki + + def update(self, current): + error = self.target_kl - current + + +class REINFORCETrainer(MultiGPUCherryPicksTrainer): + def __init__( + self, + args: REINFORCETrainingArguments, + tokenizer: PreTrainedTokenizer, + policy: nn.Module, + train_dataset: Dataset, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + callbacks: list[TrainerCallback] | None = None, + peft_config: Dict[str, Any] = {}, + policy_model_dir: str = "", + max_tokens_count: int = -1, + **kwargs, + ) -> None: + self.args = args + self.policy_model_dir = policy_model_dir + self.max_tokens_count = max_tokens_count + + # Peft policy + peft_policy = get_peft_model(policy, peft_config=LoraConfig(**peft_config)) + peft_policy.generation_config.pad_token_id = tokenizer.pad_token_id + peft_policy.print_trainable_parameters() + + super().__init__( + model=peft_policy, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + callbacks=callbacks, + **kwargs, + ) + + # id of a stop token + self.stop_generation_token_id = tokenizer.encode( + self.args.stop_token, add_special_tokens=False + ) + assert len(self.stop_generation_token_id) == 1, self.stop_generation_token_id + + # TODO: remove logprobs + self.sampling_params = vllm.SamplingParams( + min_tokens=0, + max_tokens=args.max_response_length, + temperature=(args.temperature + 1e-7), + stop_token_ids=self.stop_generation_token_id, + n=args.num_generations, + ) + + self._stored_metrics: Dict[str, Dict[str, List[float]]] = defaultdict( + lambda: defaultdict(list) + ) + self.gen_id = 1 + 
self.mean_reward = None + + # Take gradient accumulation into account + # In this way, the effective number of gradient + # updates taken into account stays the same + effective_num_previous_samples = 1 / (1 - self.args.mean_baseline_coef) + self.args.mean_baseline_coef = 1 - 1 / ( + effective_num_previous_samples * self.args.gradient_accumulation_steps + ) + print("self.args.mean_baseline_coef", self.args.mean_baseline_coef) + + if self.args.target_kl is not None: + self.kl_controller = AdaptiveKLController( + init_kl_coef=args.init_kl_coef, + target_kl=args.target_kl, + k=args.adaptive_kl_k, + clip_value=args.adaptive_kl_clip_value, + min_kl_coef=args.min_kl_coef, + max_kl_coef=args.max_kl_coef, + ) + else: + self.kl_controller = FixedKLController(value=args.kl_coef) + + self.norm_reward_mean, self.norm_reward_std = self.reward_stats( + model=self.model, dataloader=self.get_train_dataloader() + ) + + def store_metrics( + self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" + ) -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def log(self, logs: Dict[str, float]) -> None: + main_keys = list(self._stored_metrics.keys()) + for main_key in main_keys: + for key, metrics in self._stored_metrics[main_key].items(): + logs[key] = torch.tensor(metrics).float().cpu().mean().item() + + if main_key == "train": + logs["train/global_step"] = int(self.state.global_step) + else: + logs["eval/global_step"] = int( + self.state.global_step // self.args.eval_steps + ) + + del self._stored_metrics[main_key] + + return super().log(logs) # pylint: disable=no-member + + @torch.no_grad() + def reward_stats(self, model, dataloader): + if self.args.num_samples_for_reward_stats == 0: + norm_reward_mean = 0 + norm_reward_std = 1 + else: + all_rewards = [] + samples_processed = 0 + for batch in dataloader: + (_, _, _, _, rewards), _ = self.generate_responses( + peft_policy=self.accelerator.unwrap_model(model), + queries=batch["input_ids"], + attention_mask=batch["attention_mask"], + train_eval="eval", + ) + rewards, _ = self.postprocess_rewards(rewards=rewards) + + all_rewards.append(rewards) + samples_processed += ( + len(batch["input_ids"]) * self.accelerator.num_processes + ) + + if samples_processed > self.args.num_samples_for_reward_stats: + break + + all_rewards = torch.cat(all_rewards, 0) + + norm_reward_mean = get_global_mean(all_rewards) + norm_reward_std = get_global_std(all_rewards, mean=norm_reward_mean) + + return norm_reward_mean, norm_reward_std + + def batch_to_list(self, input_ids, attention_mask): + new_batch = [] + for ids, mask in zip(input_ids, attention_mask): + new_batch.append(torch.masked_select(ids, mask.to(torch.bool)).tolist()) + + return new_batch + + def generate_responses( + self, + peft_policy: torch.nn.Module, + queries: torch.Tensor, + attention_mask: torch.Tensor, + train_eval, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generates responses for a batch of queries. + For a better performance, uses vLLM. + """ + + # Dump lora weights to the disc. + # Later, this path will be passed to vllm generate. 
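+        # Only the main process writes the adapter to disk; the barriers before and
+        # after the save ensure every rank sees a complete checkpoint before asking
+        # the vLLM server for generations.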
+ self.accelerator.wait_for_everyone() + savedir = f"peft_policy_{self.gen_id}/" + if self.accelerator.is_main_process: + peft_policy.save_pretrained(savedir) + self.accelerator.wait_for_everyone() + + t0 = time.time() + prompt_token_ids = self.batch_to_list(queries, attention_mask) + batch_to_list_time = time.time() - t0 + + start_time = time.time() + pid = os.environ.get("LOCAL_RANK", -1) + port = 5000 + (int(pid) % self.args.num_servers) + response = requests.post( + f"http://127.0.0.1:{port}/generate", + data=pickle.dumps( + { + "prompt_token_ids": prompt_token_ids, + "sampling_params": self.sampling_params, + "lora_dir": savedir, + "lora_id": self.gen_id, + "max_length": self.max_tokens_count, + "train_eval": train_eval, + } + ), + headers={"Content-Type": "application/octet-stream"}, + timeout=100, + ) + self.accelerator.wait_for_everyone() + + generation_time = time.time() - start_time + if self.accelerator.is_main_process: + print("GENERATION TIME", generation_time) + + assert response.status_code == 200, ( + response.status_code, + response.headers["Content-Type"], + response.content, + ) + + ( + query_responses, + attention_mask, + response_tokens_mask, + position_ids, + rewards, + gen_time_on_server, + postprocess_time_on_server, + reward_time_on_server, + ) = pickle.loads(response.content) + + query_responses = query_responses.to(queries.device) + attention_mask = attention_mask.to(queries.device) + response_tokens_mask = response_tokens_mask.to(queries.device) + position_ids = position_ids.to(queries.device) + rewards = rewards.to(queries.device) + + assert torch.all( + torch.unique(attention_mask) + == torch.tensor([0, 1], device=attention_mask.device) + ) or torch.all( + torch.unique(attention_mask) + == torch.tensor([1], device=attention_mask.device) + ), attention_mask + assert torch.all( + torch.unique(response_tokens_mask) + == torch.tensor([0, 1], device=response_tokens_mask.device) + ), ( + response_tokens_mask.tolist(), + torch.unique(response_tokens_mask), + ) + assert query_responses.shape[1] <= self.max_tokens_count, ( + query_responses.shape, + self.max_tokens_count, + ) + assert ( + (query_responses.shape[1] == attention_mask.shape[1]) + and (query_responses.shape[1] == response_tokens_mask.shape[1]) + and (response_tokens_mask.shape[1] == position_ids.shape[1]) + ), ( + query_responses.shape, + attention_mask.shape, + response_tokens_mask.shape, + ) + assert ( + (query_responses.shape[0] == attention_mask.shape[0]) + and (query_responses.shape[0] == response_tokens_mask.shape[0]) + and (response_tokens_mask.shape[0] == rewards.shape[0]) + and (rewards.shape[0] == position_ids.shape[0]) + and (position_ids.shape[0] == (self.sampling_params.n * queries.shape[0])) + ), ( + query_responses.shape, + attention_mask.shape, + response_tokens_mask.shape, + rewards.shape, + queries.shape, + self.sampling_params.n, + ) + + # Remove the previous dump file if it exists + self.gen_id += 1 + if self.accelerator.is_main_process: + shutil.rmtree(savedir) + self.accelerator.wait_for_everyone() + + return ( + query_responses, + attention_mask, + response_tokens_mask, + position_ids, + rewards, + ), { + "generation_time": generation_time, + "query_resp_shape": query_responses.shape[1], + **self.get_log_mean_std( + response_tokens_mask.float().sum(-1), "resp_length" + ), + "gen_time_on_server": gen_time_on_server, + "postprocess_time_on_server": postprocess_time_on_server, + "reward_time_on_server": reward_time_on_server, + "batch_to_list_time": batch_to_list_time, + } + + def 
get_logprobs( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + loss_mask: torch.Tensor, + ): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ).logits[:, :-1] + logits /= self.args.temperature + + all_logprob = F.log_softmax(logits, dim=-1) + + logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze( + -1 + ) + logprob[~loss_mask[:, 1:].to(torch.bool)] = 0 + + return logprob.sum(-1) + + def postprocess_rewards(self, rewards): + raise NotImplementedError + + def fill_nonvalid_rewards( + self, rewards, query_response + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.args.non_eos_penalty: + assert torch.all(query_response[:, -1] != self.tokenizer.pad_token_id), ( + query_response[:, -1], + self.tokenizer.pad_token_id, + ) + + invalid_mask = query_response[:, -1] != self.stop_generation_token_id[0] + rewards[invalid_mask] = self.args.penalty_reward_value + + return rewards, ~invalid_mask + + return rewards, torch.ones_like(rewards).to(torch.bool) + + def process_rewards(self, rewards, query_response): + """ + Second, this token should be the . + Otherwise, return a fixed reward of -1. + """ + + rewards, reward_metrics = self.postprocess_rewards(rewards=rewards) + rewards = rewards.squeeze(-1) + reward_metrics["normalizing_reward_mean"] = self.norm_reward_mean + reward_metrics["normalizing_reward_std"] = self.norm_reward_std + rewards = (rewards - self.norm_reward_mean) / self.norm_reward_std + # boundness is required: https://arxiv.org/pdf/2406.01462 + rewards = torch.clamp( + rewards, self.args.clip_rewards_min, self.args.clip_rewards_max + ) + + rewards, valid_mask = self.fill_nonvalid_rewards( + rewards=rewards, query_response=query_response + ) + + return rewards, valid_mask, reward_metrics + + @torch.no_grad() + def update_mean_reward(self, rewards): + global_mean_reward = get_global_mean(rewards) + + if self.mean_reward is None: + self.mean_reward = global_mean_reward + else: + self.mean_reward = ( + self.args.mean_baseline_coef * self.mean_reward + + (1 - self.args.mean_baseline_coef) * global_mean_reward + ) + + def baseline_rewards(self, rewards): + baseline = self.mean_reward if self.mean_reward is not None else 0 + advantages = rewards - baseline + + with torch.no_grad(): + metrics = {"baseline_mean": baseline} + + self.update_mean_reward(rewards=rewards) + + return advantages, metrics + + def get_log_mean_std(self, tensor, name, train_eval=None): + mean = get_global_mean(tensor) + metrics = {} + if train_eval is not None: + metrics[f"{train_eval}/{name}_mean"] = mean + metrics[f"{train_eval}/{name}_std"] = get_global_std(tensor, mean=mean) + else: + metrics[f"{name}_mean"] = mean + metrics[f"{name}_std"] = get_global_std(tensor, mean=mean) + + return metrics + + def _get_batch_loss_metrics( + self, model, inputs, train_eval: Literal["train", "eval"] = "train" + ): + assert inputs["input_ids"].shape[1] <= self.max_tokens_count, ( + f"Input length {inputs['input_ids'].shape[1]} exceeds max_tokens_count {self.max_tokens_count}", + self.tokenizer.decode(inputs["input_ids"][0]), + ) + + with torch.no_grad(): + ( + ( + query_response, + attention_mask, + response_tokens_mask, + position_ids, + rewards, + ), + gen_metrics, + ) = self.generate_responses( + peft_policy=self.accelerator.unwrap_model(model), + queries=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + train_eval=train_eval, + ) + + with 
model.disable_adapter(): + ref_logprobs = self.get_logprobs( + model=model, + input_ids=query_response, + attention_mask=attention_mask, + position_ids=position_ids, + loss_mask=response_tokens_mask.to(torch.float32), + ) + + rewards, valid_mask, rewards_metrics = self.process_rewards( + rewards=rewards, query_response=query_response + ) + + logprobs = self.get_logprobs( + model=model, + input_ids=query_response, + attention_mask=attention_mask, + position_ids=position_ids, + loss_mask=response_tokens_mask, + ) + + with torch.no_grad(): + kl_term = logprobs.detach() - ref_logprobs + regularized_rewards = rewards - self.kl_controller.value * kl_term + + baselined_reward, baseline_metrics = self.baseline_rewards( + rewards=regularized_rewards + ) + self.kl_controller.update(kl_term) + + loss = -baselined_reward * logprobs + + assert len(loss.shape) == 1, loss.shape + + with torch.no_grad(): + metrics = {} + for tensor, name in zip( + [ + rewards, + regularized_rewards, + baselined_reward, + loss.detach(), + kl_term.detach(), + logprobs.detach(), + 1 - valid_mask.float(), + ], + [ + "rewards", + "regularized_rewards", + "baselined_rewards", + "loss", + "kl_term", + "logprobs", + "invalid_rewards", + ], + ): + metrics = { + **metrics, + **self.get_log_mean_std(tensor, name, train_eval), + } + metrics[f"{train_eval}/kl_coef"] = self.kl_controller.value + + for k, v in rewards_metrics.items(): + metrics[f"{train_eval}/{k}"] = v + for k, v in gen_metrics.items(): + metrics[f"{train_eval}/{k}"] = v + for k, v in baseline_metrics.items(): + metrics[f"{train_eval}/{k}"] = v + + return loss.mean(), metrics + + def compute_loss(self, model, inputs, return_outputs: bool = False): + loss, metrics = self._get_batch_loss_metrics(model, inputs, "train") + + self.store_metrics(metrics=metrics, train_eval="train") + + return (loss, metrics) if return_outputs else loss + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + with torch.no_grad(): + loss, metrics = self._get_batch_loss_metrics(model, inputs, "eval") + + self.store_metrics(metrics, train_eval="eval") + + return loss.detach(), None, None + + +class REINFORCETrainerVanillaRM(REINFORCETrainer): + def postprocess_rewards(self, rewards): + return rewards, {**self.get_log_mean_std(rewards, "real_reward")} + + +class REINFORCETrainerCategoricalPRM(REINFORCETrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.support = torch.linspace(-0.5, 0.5, 51).unsqueeze(0) + + def postprocess_rewards(self, rewards): + probs = torch.nn.functional.softmax(rewards.to(torch.float32), dim=-1) + mean = torch.sum(probs * self.support.to(rewards.device), -1) + + return mean.unsqueeze(-1), { + **self.get_log_mean_std(mean, "real_reward"), + } + + +class REINFORCETrainerCategoricalPRMVariancePenalty(REINFORCETrainer): + def __init__(self, args, *aargs, **kwargs): + super().__init__(args=args, *aargs, **kwargs) + self.risk_coef = args.risk_coef + + self.support = torch.linspace(-0.5, 0.5, 51).unsqueeze(0) + + def postprocess_rewards(self, rewards): + probs = torch.nn.functional.softmax(rewards.to(torch.float32), dim=-1) + mean = torch.sum(probs * self.support.to(rewards.device), -1) + variance = torch.sum( + probs * (self.support.to(rewards.device) - mean.unsqueeze(-1)) ** 2, -1 + ) + + return (mean - self.risk_coef * 
variance).unsqueeze(-1), {
+            **self.get_log_mean_std(mean, "real_reward"),
+            **self.get_log_mean_std(variance, "variance"),
+        }
+
+
+class REINFORCETrainerCategoricalPRMEntropyPenalty(REINFORCETrainer):
+    def __init__(self, args, *aargs, **kwargs):
+        super().__init__(args=args, *aargs, **kwargs)
+        self.risk_coef = args.risk_coef
+
+        self.support = torch.linspace(-0.5, 0.5, 51).unsqueeze(0)
+
+    def postprocess_rewards(self, rewards):
+        probs = torch.nn.functional.softmax(rewards.to(torch.float32), dim=-1)
+        mean = torch.sum(probs * self.support.to(rewards.device), -1)
+        entropy = -torch.sum(probs * torch.log(probs + 1e-9), -1)
+
+        return (mean - self.risk_coef * entropy).unsqueeze(-1), {
+            **self.get_log_mean_std(mean, "real_reward"),
+            **self.get_log_mean_std(entropy, "entropy"),
+        }
+
+
+class REINFORCETrainerRMRV(REINFORCETrainer):
+    def postprocess_rewards(self, rewards):
+        rewards, values = rewards[:, 0].unsqueeze(-1), rewards[:, 1].unsqueeze(-1)
+        with torch.no_grad():
+            metrics = {
+                **self.get_log_mean_std(rewards, "real_reward"),
+                **self.get_log_mean_std(values, "real_value"),
+                **self.get_log_mean_std((rewards < values).float(), "frac_r= N:
+                condition.notify_all()
+            else:
+                condition.wait()
+
+    response_event.wait()
+    return Response(request_entry["response"], content_type="application/octet-stream")
+
+
+def save_rewards(rewards, step):
+    np.save(
+        os.path.join(rewards_savedir, f"rewards_{step}_on_port_{args.port}.npy"),
+        rewards.to(torch.float32).cpu().numpy(),
+    )
+
+
+@torch.inference_mode()
+def batch_process():
+    global request_queue, step
+    while True:
+        with condition:
+            while len(request_queue) < N:
+                condition.wait()
+
+            batch_data = request_queue[:N]
+            request_queue = request_queue[N:]
+            condition.notify_all()
+
+        # Validate input data
+        if not all("prompt_token_ids" in entry["data"] for entry in batch_data):
+            continue  # Skip this batch if validation fails
+
+        prompt_token_ids_list = []
+        for entry in batch_data:
+            prompt_token_ids_list.extend(entry["data"].get("prompt_token_ids", []))
+
+        sampling_params_list = [
+            entry["data"].get("sampling_params", "") for entry in batch_data
+        ]
+        lora_dir_list = [entry["data"].get("lora_dir", []) for entry in batch_data]
+        max_length_list = [
+            entry["data"].get("max_length", None) for entry in batch_data
+        ]
+        lora_id_list = [entry["data"].get("lora_id", None) for entry in batch_data]
+        train_eval_list = [
+            entry["data"].get("train_eval", None) for entry in batch_data
+        ]
+
+        assert all(
+            vars(param) == vars(sampling_params_list[0])
+            for param in sampling_params_list
+        ), f"Not all sampling_params are the same: {sampling_params_list}"
+        assert all(
+            dir == lora_dir_list[0] for dir in lora_dir_list
+        ), f"Not all lora_dirs are the same: {lora_dir_list}"
+        assert all(
+            max_length == max_length_list[0] for max_length in max_length_list
+        ), f"Not all max_lengths are the same: {max_length_list}"
+        assert all(
+            lora_id == lora_id_list[0] for lora_id in lora_id_list
+        ), f"Not all lora_ids are the same: {lora_id_list}"
+        assert all(
+            train_eval == train_eval_list[0] for train_eval in train_eval_list
+        ), f"Not all train_evals are the same: {train_eval_list}"
+
+        # There is no need to generate all the way to 'max_tokens';
+        # what matters is that the combined length of prompt and answer stays below 'max_tokens_count'.
+        sampling_params = sampling_params_list[0]
+        sampling_params.max_tokens = min(
+            sampling_params.max_tokens,
+            max_length_list[0] - min([len(x) for x in prompt_token_ids_list]),
+        )
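+        # Worked example with hypothetical numbers (not taken from any config):
+        # if sampling_params.max_tokens is 1024, max_length_list[0] is 2048 and the
+        # shortest prompt in the batch has 1536 tokens, the cap above becomes
+        # min(1024, 2048 - 1536) = 512, i.e. generation stops early enough that the
+        # shortest prompt plus its answer still fits into the requested max_length.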
print("num_prompts", len(prompt_token_ids_list)) + t0 = time.time() + generations = generating_policy.generate( + prompts=None, + prompt_token_ids=prompt_token_ids_list, + sampling_params=sampling_params, + lora_request=vllm.lora.request.LoRARequest( + lora_dir_list[0], lora_id_list[0], lora_dir_list[0] + ), + ) + t1 = time.time() + gen_time = t1 - t0 + # pad_to_max_len because I use torch.compile and I don't want + query_responses, attention_mask, response_tokens_mask, position_ids = ( + vllm_generations_postprocess(tokenizer, generations, max_length_list[0]) + ) + query_responses, attention_mask, response_tokens_mask, position_ids = ( + query_responses.cuda(), + attention_mask.cuda(), + response_tokens_mask.cuda(), + position_ids.cuda(), + ) + + t2 = time.time() + postprocess_time = t2 - t1 + + # the whole batch doesn't fit into memory + rewards = [] + mini_batch_size = len(query_responses) // 2 + for i in range(0, len(query_responses), mini_batch_size): + mini_batch_query_responses = query_responses[i : i + mini_batch_size] + mini_batch_attention_mask = attention_mask[i : i + mini_batch_size] + mini_batch_response_mask = response_tokens_mask[i : i + mini_batch_size] + mini_batch_position_ids = position_ids[i : i + mini_batch_size] + + mini_batch_rewards = reward_model( + input_ids=mini_batch_query_responses, + attention_mask=mini_batch_attention_mask, + response_mask=mini_batch_response_mask, + position_ids=mini_batch_position_ids, + )["logits"] + + rewards.append(mini_batch_rewards.cpu()) + + rewards = torch.cat(rewards, dim=0) + + if train_eval_list[0] == "train": + save_rewards(rewards, step) + step += 1 + + t3 = time.time() + reward_time = t3 - t2 + + # Group generations correctly + response_index = 0 + for entry in batch_data: + subbatch_size = ( + len(entry["data"].get("prompt_token_ids", [])) + * entry["data"].get("sampling_params", 1).n + ) + + entry["response"] = pickle.dumps( + ( + query_responses[ + response_index : response_index + subbatch_size + ].cpu(), + attention_mask[ + response_index : response_index + subbatch_size + ].cpu(), + response_tokens_mask[ + response_index : response_index + subbatch_size + ].cpu(), + position_ids[response_index : response_index + subbatch_size].cpu(), + rewards[response_index : response_index + subbatch_size].cpu(), + gen_time, + postprocess_time, + reward_time, + ) + ) + response_index += subbatch_size + entry["response_event"].set() + + with condition: + condition.notify_all() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="VLLM Server Batched") + parser.add_argument( + "--experiment_settings_path", + type=str, + help="Path to the experiment settings file", + ) + parser.add_argument( + "--total_num_gpus", + type=int, + help="Number of gpus used for training", + ) + parser.add_argument( + "--num_servers", + type=int, + help="Number of gpus dedicated for generation server", + ) + parser.add_argument( + "--port", + type=int, + help="Server Port", + ) + args = parser.parse_args() + + # num available gpus + num_gpus = torch.cuda.device_count() + print(f"Number of available GPUs: {num_gpus}") + + N = (args.total_num_gpus - args.num_servers) // args.num_servers + assert (args.total_num_gpus - args.num_servers) % args.num_servers == 0, ( + args.total_num_gpus, + args.num_servers, + ) + print("NUM_GPUS for policy training", N) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + policy_path = "/from_s3/models/policy" + generating_policy = vllm.LLM( + model=policy_path, + dtype="bfloat16", + 
enable_lora=True, + max_lora_rank=64, + skip_tokenizer_init=True, + max_model_len=2048, + max_seq_len_to_capture=2048, + gpu_memory_utilization=0.75, + ) + tokenizer = AutoTokenizer.from_pretrained(policy_path) + + experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file( + args.experiment_settings_path + ) + rewards_savedir = os.path.join(str(experiment_settings.log_path), "rewards") + os.makedirs(rewards_savedir, exist_ok=True) + + reward_model = ( + load_model( + model_settings=experiment_settings.reward_model_settings, + tokenizer=tokenizer, + ) + .cuda() + .eval() + ) + # reward_model.model = torch.compile(reward_model.model) + + request_queue = [] + condition = threading.Condition() + + threading.Thread(target=batch_process, daemon=True).start() + app.run(host="0.0.0.0", port=args.port, threaded=True)
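
Reviewer note: the trainer logic in this patch is spread across several helpers, so below is a minimal, self-contained sketch of the per-batch update as I read it from _get_batch_loss_metrics, baseline_rewards and update_mean_reward (reward normalization and clipping, KL-regularized reward, moving-average baseline, -advantage * logprob loss). The function name, argument list and shapes are illustrative assumptions and not part of the patch; distributed reductions (get_global_mean/get_global_std), metric logging and the adaptive KL controller are deliberately omitted.

import torch


def reinforce_step(
    logprobs: torch.Tensor,      # summed log-probs of response tokens under the current policy, shape [B]
    ref_logprobs: torch.Tensor,  # the same quantity with adapters disabled (reference policy), shape [B]
    rewards: torch.Tensor,       # raw reward-model scores, shape [B]
    norm_mean: float,            # reward normalization statistics
    norm_std: float,
    kl_coef: float,
    clip_min: float,
    clip_max: float,
    baseline: float,             # running mean of regularized rewards
    baseline_coef: float,        # EMA coefficient for the baseline
):
    # 1. Normalize and clip the raw rewards (boundedness).
    rewards = torch.clamp((rewards - norm_mean) / norm_std, clip_min, clip_max)

    # 2. KL-regularize against the reference policy; the KL term carries no gradient.
    kl_term = logprobs.detach() - ref_logprobs
    regularized_rewards = rewards - kl_coef * kl_term

    # 3. Subtract the moving-average baseline to reduce gradient variance.
    advantages = regularized_rewards - baseline

    # 4. REINFORCE loss: increase log-probability of high-advantage responses.
    loss = -(advantages * logprobs).mean()

    # 5. Update the baseline as an exponential moving average of regularized rewards.
    new_baseline = (
        baseline_coef * baseline
        + (1 - baseline_coef) * regularized_rewards.mean().item()
    )

    return loss, new_baseline

Calling loss.backward() on the returned value would then drive the policy update, much as the value returned from compute_loss does inside the Trainer loop.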