diff --git a/configs/exp/train/dpo/dpo.json b/configs/exp/train/dpo/dpo.json index ba6e9fe..52c8b83 100755 --- a/configs/exp/train/dpo/dpo.json +++ b/configs/exp/train/dpo/dpo.json @@ -100,7 +100,7 @@ "bos_token": "<|begin_of_text|>", "eos_token": "<|end_of_text|>" }, - "trainer_settings": { + "trainer_settings": { "evaluation_strategy": "steps", "per_device_train_batch_size": 1, "per_device_eval_batch_size": 1, diff --git a/configs/exp/train/reinforce/reinforce.json b/configs/exp/train/reinforce/reinforce.json new file mode 100644 index 0000000..1588bfc --- /dev/null +++ b/configs/exp/train/reinforce/reinforce.json @@ -0,0 +1,214 @@ +{ + "train_dataset_settings": { + "sources": [ + { + "name": "train_hh", + "records_path": "/from_s3/datasets/train_chat_filtered.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 1024, + "only_answer_loss": true, + "keep_end": true, + "random_cut": false, + "check_max_len_enforced": true + }, + "val_dataset_settings": { + "sources": [ + { + "name": "val_hh", + "records_path": "/from_s3/datasets/val_chat_filtered.jsonl", + "sample_rate": 0.2 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 1024, + "only_answer_loss": true, + "keep_end": true, + "random_cut": false, + "check_max_len_enforced": true + }, + + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + "do_sample": true, + "num_return_sequences": 3, + "max_new_tokens": 1024, + "temperature": 0.7, + "top_p": 1.0, + "top_k": 0, + "repetition_penalty": 1.0 + }, + "custom_generation_settings": { + "generation_eos_token": "<|eot_id|>", + "remove_prompt": true, + "skip_special_tokens": true + }, + "dataset_settings": { + "sources": [ + { + "name": "cp_hh", + "records_path": "/from_s3/datasets/val_chat_filtered.jsonl", + "num_samples": 10 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 1024, + "random_cut": false, + "keep_end": true, + "only_answer_loss": true, + "check_max_len_enforced": true + }, + "metric_settings": [ + { + "type": "length", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "self_bleu", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "dist_n", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "diversity", + "parameters": { + "need_average": [true, false] + } + }, + { + "type": "perplexity", + "parameters": { + "need_average": [true, false] + } + } + ] + }, + "reward_model_settings": { + "model_path": "/from_s3/models/rm", + "model_type": "llama_reward_value", + "model_kwargs": { + "num_labels": 1, + "attn_implementation": "flash_attention_2" + }, + "transformers_settings": {}, + "check_checkpoint_is_compatible": true + }, + "model_settings": { + "model_path": "/from_s3/models/policy/", + "model_type": "causal", + "transformers_settings": {}, + "model_kwargs": { + 
"attn_implementation": "flash_attention_2" + }, + "check_checkpoint_is_compatible": true + }, + "tokenizer_settings": { + "use_fast": true, + "padding_side": "left" + }, + "peft_settings": { + "inference_mode": false, + "r": 64, + "lora_alpha": 64, + "lora_dropout": 0.0, + "use_rslora": false, + "init_lora_weights": true, + "task_type": "CAUSAL_LM" + }, + "trainer_settings": { + "evaluation_strategy": "steps", + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "gradient_accumulation_steps": 10, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_epsilon": 1e-5, + "eval_steps": 50, + "save_steps": 100, + "save_strategy": "steps", + "load_best_model_at_end": false, + "logging_steps": 1, + "learning_rate": 1e-5, + "max_steps": 1001, + "lr_scheduler_type": "constant_with_warmup", + "warmup_ratio": 0.03, + "fp16": false, + "bf16": true, + "optim": "adamw_torch", + "weight_decay": 0.0, + "max_grad_norm": 2, + "save_total_limit": 11, + "dataloader_num_workers": 12, + "deepspeed": "configs/exp/deepspeed/ds_config_stage_2_reinforce.json", + + "max_response_length": 1024, + "stop_token": "<|eot_id|>", + "temperature": 0.7, + "non_eos_penalty": true, + "penalty_reward_value": -1, + "clip_rewards_min": -1e8, + "clip_rewards_max": 1e8, + "whiten_rewards": false, + "kl_coef": 0.05, + "risk_coef": 0.00, + "mean_baseline_coef": 0.95, + + "init_kl_coef": null, + "target_kl": null, + "adaptive_kl_k": null, + "adaptive_kl_clip_value": null, + "min_kl_coef": null, + "max_kl_coef": null, + + "num_samples_for_reward_stats": 1000 + + }, + "wandb_settings": { + "project_name": "...", + "group": "hh-llama3-8b", + "job_type": "reinforce", + "run_name": "rvm-kl_0.05", + "entity": "...", + "tags": ["reinforce", "llama3-8b", "hh", "rvm"] + }, + "log_path": "train_output/hh/llama3-8b/reinforce/rvm-kl_0.05-0.01/" +} + diff --git a/tests/fixtures/configs/train/reinforce/base.json b/tests/fixtures/configs/train/reinforce/base.json new file mode 100644 index 0000000..520ced7 --- /dev/null +++ b/tests/fixtures/configs/train/reinforce/base.json @@ -0,0 +1,175 @@ +{ + "train_dataset_settings": { + "sources": [ + { + "name": "train_chat", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 2000, + "only_answer_loss": true + }, + "val_dataset_settings": { + "sources": [ + { + "name": "val_chat", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 2000, + "only_answer_loss": true + }, + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + "do_sample": false, + "stop_strings": "", + "max_new_tokens": 8 + }, + "custom_generation_settings": { + "skip_special_tokens": false + }, + "dataset_settings": { + "sources": [ + { + "name": "chat_test", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "num_samples": 2 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + 
"prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 150, + "only_answer_loss": true + }, + "metric_settings": [ + { + "type": "length", + "parameters": { + "need_average": [ + true + ] + } + }, + { + "type": "kl", + "parameters": { + "need_average": [ + true + ], + "ref_logits_type": "sft" + } + }, + { + "type": "kl", + "parameters": { + "need_average": [ + true + ], + "ref_logits_type": "reference" + } + } + ] + }, + "reward_model_settings": { + "model_path": "tests/fixtures/models/llama2_tiny", + "model_type": "seq_cls", + "model_kwargs": { + "num_labels": 1 + }, + "transformers_settings": {} + }, + "model_settings": { + "model_path": "tests/fixtures/models/llama2_tiny", + "model_type": "causal", + "transformers_settings": {}, + "embeddings_initialization_strategy": { + "<|start_header_id|>": "", + "<|end_header_id|>": "", + "<|eot_id|>": "" + } + }, + "tokenizer_settings": { + "tokenizer_kwargs": { + "padding_side": "left" + } + }, + "trainer_settings": { + "evaluation_strategy": "steps", + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "gradient_accumulation_steps": 10, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_epsilon": 0.00001, + "eval_steps": 50, + "save_steps": 100, + "save_strategy": "steps", + "load_best_model_at_end": false, + "logging_steps": 1, + "learning_rate": 0.00001, + "max_steps": 1001, + "lr_scheduler_type": "constant_with_warmup", + "warmup_ratio": 0.03, + "fp16": false, + "bf16": false, + "optim": "adamw_torch", + "weight_decay": 0.0, + "max_grad_norm": 2, + "save_total_limit": 11, + "dataloader_num_workers": 12, + "stop_token": "<|eot_id|>", + "temperature": 0.7, + "non_eos_penalty": true, + "penalty_reward_value": -1, + "clip_rewards_min": -1e+8, + "clip_rewards_max": 1e+8, + "whiten_rewards": false, + "kl_coef": 0.05, + "mean_baseline_coef": 0.95, + "num_samples_for_reward_stats": 1000, + "no_cuda": true + }, + "special_tokens_settings": { + "bos_token": "", + "eos_token": "", + "pad_token": "" + }, + "logging_settings": { + "project_name": "alignment", + "run_name": "sft", + "entity": "turbo-alignment" + }, + "seed": 0, + "log_path": "train_output" +} diff --git a/tests/fixtures/configs/train/reinforce/gpt2_base.json b/tests/fixtures/configs/train/reinforce/gpt2_base.json new file mode 100644 index 0000000..7560fb9 --- /dev/null +++ b/tests/fixtures/configs/train/reinforce/gpt2_base.json @@ -0,0 +1,175 @@ +{ + "train_dataset_settings": { + "sources": [ + { + "name": "train_chat", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 2000, + "only_answer_loss": true + }, + "val_dataset_settings": { + "sources": [ + { + "name": "val_chat", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "sample_rate": 1.0 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 2000, + "only_answer_loss": true + }, + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + 
"do_sample": false, + "stop_strings": "", + "max_new_tokens": 8 + }, + "custom_generation_settings": { + "skip_special_tokens": false + }, + "dataset_settings": { + "sources": [ + { + "name": "chat_test", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "num_samples": 2 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "assistant", + "user": "user", + "system": "system" + }, + "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n", + "suffix_template": "<|eot_id|>" + }, + "dataset_type": "chat", + "max_tokens_count": 150, + "only_answer_loss": true + }, + "metric_settings": [ + { + "type": "length", + "parameters": { + "need_average": [ + true + ] + } + }, + { + "type": "kl", + "parameters": { + "need_average": [ + true + ], + "ref_logits_type": "sft" + } + }, + { + "type": "kl", + "parameters": { + "need_average": [ + true + ], + "ref_logits_type": "reference" + } + } + ] + }, + "reward_model_settings": { + "model_path": "gpt2", + "model_type": "seq_cls", + "model_kwargs": { + "num_labels": 1 + }, + "transformers_settings": {} + }, + "model_settings": { + "model_path": "gpt2", + "model_type": "causal", + "transformers_settings": {}, + "embeddings_initialization_strategy": { + "<|start_header_id|>": "<|endoftext|>", + "<|end_header_id|>": "<|endoftext|>", + "<|eot_id|>": "<|endoftext|>" + } + }, + "tokenizer_settings": { + "tokenizer_kwargs": { + "padding_side": "left" + } + }, + "trainer_settings": { + "evaluation_strategy": "steps", + "per_device_train_batch_size": 1, + "per_device_eval_batch_size": 1, + "gradient_accumulation_steps": 1, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_epsilon": 0.00001, + "eval_steps": 50, + "save_steps": 100, + "save_strategy": "steps", + "load_best_model_at_end": false, + "logging_steps": 1, + "learning_rate": 0.00001, + "max_steps": 1001, + "lr_scheduler_type": "constant_with_warmup", + "warmup_ratio": 0.03, + "fp16": false, + "bf16": false, + "optim": "adamw_torch", + "weight_decay": 0.0, + "max_grad_norm": 2, + "save_total_limit": 11, + "dataloader_num_workers": 12, + "stop_token": "<|eot_id|>", + "temperature": 0.7, + "non_eos_penalty": true, + "penalty_reward_value": -1, + "clip_rewards_min": -1e+8, + "clip_rewards_max": 1e+8, + "whiten_rewards": false, + "kl_coef": 0.05, + "mean_baseline_coef": 0.95, + "num_samples_for_reward_stats": 10, + "no_cuda": true + }, + "special_tokens_settings": { + "bos_token": "<|endoftext|>", + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>" + }, + "logging_settings": { + "project_name": "alignment", + "run_name": "sft", + "entity": "turbo-alignment" + }, + "seed": 0, + "log_path": "train_output" +} diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py index 80c6022..ba1c536 100755 --- a/turbo_alignment/cherry_picks/multimodal.py +++ b/turbo_alignment/cherry_picks/multimodal.py @@ -1,10 +1,10 @@ from typing import Iterable import torch +import wandb from PIL import Image from transformers import PreTrainedTokenizerBase -import wandb from turbo_alignment.cherry_picks.base import CherryPickCallbackBase from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset from turbo_alignment.generators.multimodal import MultimodalGenerator diff --git a/turbo_alignment/cherry_picks/rm.py b/turbo_alignment/cherry_picks/rm.py index 6cdcbcc..a86d7da 100755 --- a/turbo_alignment/cherry_picks/rm.py +++ b/turbo_alignment/cherry_picks/rm.py @@ -5,7 +5,7 @@ from turbo_alignment.cherry_picks.base import 
CherryPickCallbackBase from turbo_alignment.dataset.pair_preferences import PairPreferenceDataset -from turbo_alignment.generators.rm import RMPairGenerator +from turbo_alignment.generators import RMSamplingGenerator from turbo_alignment.metrics.metric import Metric from turbo_alignment.settings.cherry_pick import RMCherryPickSettings from turbo_alignment.settings.metric import ElementWiseScores, MetricResults @@ -29,7 +29,8 @@ def _get_dataset_metrics( ) -> list[MetricResults]: accelerator: Accelerator = kwargs.get('accelerator', None) - generator = RMPairGenerator( + # FIXME: convert paired dataset to sampling dataset + generator = RMSamplingGenerator( model=model, tokenizer=tokenizer, accelerator=accelerator, diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py index da109ff..aa522ad 100755 --- a/turbo_alignment/cli/train.py +++ b/turbo_alignment/cli/train.py @@ -14,7 +14,7 @@ def train_sft_entrypoint( '--experiment_settings_path', exists=True, help='Path to experiment config file', - ) + ), ) -> None: experiment_settings = pipeline_settings.SftTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainSFTStrategy().run(experiment_settings) @@ -27,7 +27,7 @@ def train_dpo_entrypoint( '--experiment_settings_path', exists=True, help='Path to experiment config file', - ) + ), ) -> None: experiment_settings = pipeline_settings.DPOTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainDPOStrategy().run(experiment_settings) @@ -40,7 +40,7 @@ def train_kto_entrypoint( '--experiment_settings_path', exists=True, help='Path to experiment config file', - ) + ), ) -> None: experiment_settings = pipeline_settings.KTOTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainKTOStrategy().run(experiment_settings) @@ -53,7 +53,7 @@ def train_ddpo_entrypoint( '--experiment_settings_path', exists=True, help='Path to experiment config file', - ) + ), ) -> None: experiment_settings = pipeline_settings.DDPOTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainDDPOStrategy().run(experiment_settings) @@ -66,7 +66,7 @@ def train_rm_entrypoint( '--experiment_settings_path', exists=True, help='Path to experiment config file', - ) + ), ) -> None: experiment_settings = pipeline_settings.RMTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainRMStrategy().run(experiment_settings) @@ -79,7 +79,7 @@ def classification_training( '--experiment_settings_path', exists=True, help='Path to experiment config file', - ) + ), ) -> None: experiment_settings = pipeline_settings.ClassificationTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainClassificationStrategy().run(experiment_settings) @@ -88,8 +88,11 @@ def classification_training( @app.command(name='train_multimodal', help='Train Multimodal') def multimodal_training( experiment_settings_path: Path = typer.Option( - ..., '--experiment_settings_path', exists=True, help='Path to experiment config file' - ) + ..., + '--experiment_settings_path', + exists=True, + help='Path to experiment config file', + ), ) -> None: experiment_settings = pipeline_settings.MultimodalTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainMultimodalStrategy().run(experiment_settings) @@ -98,8 +101,24 @@ def multimodal_training( @app.command(name='train_rag', help='Train RAG') def rag_training( experiment_settings_path: Path = typer.Option( - ..., '--experiment_settings_path', exists=True, help='Path to experiment config file' - ) + 
..., + '--experiment_settings_path', + exists=True, + help='Path to experiment config file', + ), ) -> None: experiment_settings = pipeline_settings.RAGTrainExperimentSettings.parse_file(experiment_settings_path) pipelines.TrainRAGStrategy().run(experiment_settings) + + +@app.command(name='train_reinforce', help='Train REINFORCE pipeline') +def reinforce_training( + experiment_settings_path: Path = typer.Option( + ..., + '--experiment_settings_path', + exists=True, + help='Path to experiment config file', + ), +) -> None: + experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path) + pipelines.TrainREINFORCEStrategy().run(experiment_settings) diff --git a/turbo_alignment/common/data/multimodal/audio/imagebind.py b/turbo_alignment/common/data/multimodal/audio/imagebind.py index c27503b..e9da72b 100644 --- a/turbo_alignment/common/data/multimodal/audio/imagebind.py +++ b/turbo_alignment/common/data/multimodal/audio/imagebind.py @@ -1,7 +1,12 @@ from pathlib import Path import torch -import torchaudio + +# FIXME: torchaudio is an optional dependency +try: + import torchaudio +except ImportError: + pass from loguru import logger from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler from torchvision import transforms diff --git a/turbo_alignment/common/distributed.py b/turbo_alignment/common/distributed.py new file mode 100644 index 0000000..a81c3cf --- /dev/null +++ b/turbo_alignment/common/distributed.py @@ -0,0 +1,128 @@ +from datetime import timedelta +from typing import Any, Literal + +import numpy as np +import torch +import os +import torch.distributed +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) + + +world_size = int(os.getenv("WORLD_SIZE", "1")) + + +def get_global_mean(values: torch.Tensor) -> float: + # Sum of rewards on the current process + local_sum = values.sum().item() + + if world_size == 1: + return values.mean().item() + + num_rewards = torch.tensor(len(values), device=values.device) + + # Create a tensor to hold the global reward sum + global_sum = torch.tensor(local_sum, device=values.device) + + # Aggregate sums and counts across all processes + torch.distributed.all_reduce(global_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(num_rewards, op=torch.distributed.ReduceOp.SUM) + global_mean = (global_sum / num_rewards).item() + + return global_mean + + +def get_global_std(values: torch.Tensor, mean: float) -> float: + local_sum = ((values - mean) ** 2).sum().item() + + if world_size == 1: + return float(np.sqrt(local_sum / (len(values) - 1))) + + # Number of rewards on the current process + num_rewards = torch.tensor(len(values), device=values.device) + + # Create a tensor to hold the global sum of squared deviations + global_sum = torch.tensor(local_sum, device=values.device) + + # Aggregate sums and counts across all processes + torch.distributed.all_reduce(global_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(num_rewards, op=torch.distributed.ReduceOp.SUM) + global_std = np.sqrt((global_sum / (num_rewards - 1)).item()) + + return global_std + + +def get_log_mean_std( + tensor: torch.Tensor, name: str, train_eval: Literal['train', 'eval'] | None = None +) -> dict[str, float]: + mean = get_global_mean(tensor) + metrics = {} + if train_eval is not None: + metrics[f'{train_eval}/{name}_mean'] = mean + metrics[f'{train_eval}/{name}_std'] = get_global_std(tensor, mean=mean) + else: + metrics[f'{name}_mean'] = mean +
metrics[f'{name}_std'] = get_global_std(tensor, mean=mean) + + return metrics + + +# Copy from pytorch to allow creating multiple main groups. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py +def init_process_group( + backend: str | Backend | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str | None = None, + pg_options: Any = None, +): + assert (store is None) or (init_method is None), 'Cannot specify both init_method and store.' + + if store is not None: + assert world_size > 0, 'world_size must be positive if using store' + assert rank >= 0, 'rank must be non-negative if using store' + elif init_method is None: + init_method = 'env://' + + if backend: + backend = Backend(backend) + else: + backend = Backend('undefined') + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + pg_options=pg_options, + timeout=timeout, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/turbo_alignment/common/logging/weights_and_biases.py b/turbo_alignment/common/logging/weights_and_biases.py index 112e585..3b5bdfd 100755 --- a/turbo_alignment/common/logging/weights_and_biases.py +++ b/turbo_alignment/common/logging/weights_and_biases.py @@ -1,9 +1,9 @@ from typing import Any +import wandb from wandb.sdk.lib.disabled import RunDisabled from wandb.sdk.wandb_run import Run -import wandb from turbo_alignment.settings.logging.weights_and_biases import WandbSettings diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py index 1ad484d..646e280 100755 --- a/turbo_alignment/common/tf/callbacks/logging.py +++ b/turbo_alignment/common/tf/callbacks/logging.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import wandb from clearml import Task from transformers import ( TrainerCallback, @@ -13,7 +14,6 @@ from wandb.sdk.lib.disabled import RunDisabled from wandb.sdk.wandb_run import Run -import wandb from turbo_alignment.common.logging import get_project_logger logger = get_project_logger() diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index 5415daf..86ddcfd 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -4,7 +4,12 @@ import torch from accelerate import Accelerator -from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase +from transformers import ( + BatchEncoding, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) from turbo_alignment.dataset.base import BaseDataset from turbo_alignment.dataset.base.models import DatasetRecord @@ -34,11 +39,11 @@ def __init__( self._batch = batch @abstractmethod - def _generate_from_batch( + def generate_from_batch( self, - records: list[dict[str, Any]], - original_records: list[DatasetRecordT], dataset_name: str, + records: list[dict[str, Any]], + original_records: list[DatasetRecordT] | None = 
None, ) -> list[InferenceOutputT]: ... @@ -64,10 +69,10 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]: zip(records_batches, original_records_batches) ): generations.append( - self._generate_from_batch( + self.generate_from_batch( + dataset.source.name, records_batch, original_records_batch, - dataset.source.name, ) ) else: @@ -75,10 +80,10 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]: list(zip(records_batches, original_records_batches)), apply_padding=True ) as accelerator_records: generations = [ - self._generate_from_batch( + self.generate_from_batch( + dataset.source.name, records_batch, original_records_batch, - dataset.source.name, ) for records_batch, original_records_batch in accelerator_records ][: len(records_batches)] @@ -92,13 +97,10 @@ def __init__( transformers_settings: GeneratorTransformersSettings, custom_generation_settings: CustomChatGenerationSettings, tokenizer: PreTrainedTokenizerBase, - return_logits: bool = False, **kwargs, ) -> None: super().__init__(tokenizer=tokenizer, **kwargs) - self._return_logits = return_logits - self._transformers_generator_parameters = GenerationConfig( bos_token_id=self._tokenizer.bos_token_id, **transformers_settings.dict(), @@ -107,25 +109,28 @@ def __init__( self._custom_generation_settings = custom_generation_settings @abstractmethod - def _generate_from_single_record( + def generate_from_single_record( self, - record: dict[str, torch.Tensor], - original_record: DatasetRecordT, dataset_name: str, + record: dict[str, torch.Tensor], + original_record: DatasetRecordT | None = None, ) -> InferenceOutputT: ... @abstractmethod - def _generate_from_batch_records( + def generate_from_batch_records( self, - records: list[dict[str, torch.Tensor]], - original_records: list[DatasetRecordT], dataset_name: str, + records_batch: dict[str, torch.Tensor] | BatchEncoding, + original_records: list[DatasetRecordT] | None = None, ) -> list[InferenceOutputT]: ... 
- def _generate_from_batch( - self, records: list[dict[str, Any]], original_records: list[DatasetRecordT], dataset_name: str + def generate_from_batch( + self, + dataset_name: str, + records: list[dict[str, Any]], + original_records: list[DatasetRecordT] | None = None, ) -> list[InferenceOutputT]: if self._custom_generation_settings.batch > 1: if self._transformers_generator_parameters.num_beams != 1: @@ -134,18 +139,52 @@ def _generate_from_batch( self._tokenizer.padding_side = 'left' self._tokenizer.pad_token_id = self._tokenizer.pad_token_id - return self._generate_from_batch_records(records, original_records, dataset_name) + input_ids = [record['input_ids'].tolist() for record in records] + attention_mask = [record['attention_mask'].tolist() for record in records] + + max_input_length = max(len(sample) for sample in input_ids) + + records_batch = self._tokenizer.pad( + { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + }, + padding='max_length', + max_length=max_input_length, + return_tensors='pt', + ) + + return self.generate_from_batch_records(dataset_name, records_batch, original_records) return [ - self._generate_from_single_record(record, original_record, dataset_name) + self.generate_from_single_record(dataset_name, record, original_record) for record, original_record in zip(records, original_records) ] - @staticmethod - def _postprocess(input_indices: torch.Tensor, output_indices: torch.Tensor, remove_prompt: bool) -> torch.Tensor: + def _postprocess( + self, + input_indices: torch.Tensor, + output_indices: torch.Tensor, + logits: torch.Tensor | None, + remove_prompt: bool, + only_answer_logits: bool, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, list[str]]: + processed_logits: torch.Tensor | None = None + processed_output_indices = output_indices.cpu() + + if logits is not None: + processed_logits = logits.cpu() + if only_answer_logits: + processed_logits = logits[:, input_indices.shape[1] :, :].cpu() + if remove_prompt: - return output_indices[:, input_indices.shape[1] :].cpu() - return output_indices.cpu() + processed_output_indices = output_indices[:, input_indices.shape[1] :] + + answers = self._decode(token_indices=processed_output_indices) + + answers_attention_mask = processed_output_indices != self._tokenizer.pad_token_id + + return processed_output_indices, answers_attention_mask, processed_logits, answers def _decode(self, token_indices: torch.Tensor) -> list[str]: return self._tokenizer.batch_decode( diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index acc4337..579491c 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -1,6 +1,7 @@ from typing import Any import torch +from transformers import BatchEncoding from turbo_alignment.dataset.chat.models import ChatDatasetRecord from turbo_alignment.generators.base import ChatGeneratorBase @@ -11,30 +12,16 @@ class ChatGenerator(ChatGeneratorBase[ChatDatasetRecord, ChatInferenceOutput]): - def _generate_from_batch_records( + def generate_from_batch_records( self, - records: list[dict[str, torch.Tensor]], - original_records: list[ChatDatasetRecord], dataset_name: str, + records_batch: dict[str, torch.Tensor] | BatchEncoding, + original_records: list[ChatDatasetRecord] | None = None, ) -> list[ChatInferenceOutput]: - input_ids = [record['input_ids'].tolist() for record in records] - attention_mask = [record['attention_mask'].tolist() for record in records] - - max_input_length = max(len(sample) for sample in input_ids) 
- - batch = self._tokenizer.pad( - { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - }, - padding='max_length', - max_length=max_input_length, - return_tensors='pt', - ) - - batched_input_ids = batch['input_ids'].to(self.device) - batched_attention_mask = batch['attention_mask'].to(self.device) + batched_input_ids = records_batch['input_ids'].to(self.device) + batched_attention_mask = records_batch['attention_mask'].to(self.device) + assert self._tokenizer.padding_side == 'left' output_indices = self._model.generate( inputs=batched_input_ids, attention_mask=batched_attention_mask, @@ -43,39 +30,60 @@ def _generate_from_batch_records( pad_token_id=self._tokenizer.pad_token_id, ) - postprocessed_output_indices = self._postprocess( + logits: torch.Tensor | None = None + + if self._custom_generation_settings.return_logits: + with torch.no_grad(): + logits = self._model(output_indices).logits + + postprocessed_output_indices, answers_attention_mask, postprocessed_logits, answers = self._postprocess( input_indices=batched_input_ids, output_indices=output_indices, + logits=logits, remove_prompt=self._custom_generation_settings.remove_prompt, + only_answer_logits=self._custom_generation_settings.only_answer_logits, ) - answers = self._decode(token_indices=postprocessed_output_indices) - - return [ - ChatInferenceOutput( - id=original_record.id, - dataset_name=dataset_name, - messages=original_record.messages, - label=original_record.label, - meta=original_record.meta, - answers=[ - AnswerMessage( - id='0', - content=answer, - input_token_ids=None, - answer_token_ids=None, - logits=None, + outputs = [] + for i, (input_ids, attention_mask) in enumerate(zip(batched_input_ids, batched_attention_mask)): + ans_batch_start = i * self._transformers_generator_parameters.num_return_sequences + ans_batch_end = (i + 1) * self._transformers_generator_parameters.num_return_sequences + batch_answer_tokens = postprocessed_output_indices[ans_batch_start:ans_batch_end, :] + batch_answer_attention_masks = answers_attention_mask[ans_batch_start:ans_batch_end, :] + batch_answers = answers[ans_batch_start:ans_batch_end] + + for j, (answer, answer_tokens, answer_attention_mask) in enumerate( + zip(batch_answers, batch_answer_tokens, batch_answer_attention_masks) + ): + answer_logits = None if postprocessed_logits is None else postprocessed_logits[i + j, :, :] + + outputs.append( + ChatInferenceOutput( + dataset_name=dataset_name, + messages=original_records[i].messages if original_records else None, + label=original_records[i].label if original_records else None, + meta=original_records[i].meta if original_records else None, + input_token_ids=input_ids, + input_attention_mask=attention_mask, + answers=[ + AnswerMessage( + id=str(i), + content=answer, + answer_token_ids=answer_tokens, + logits=answer_logits, + answer_attention_mask=answer_attention_mask, + ) + ], ) - ], - ) - for original_record, answer in zip(original_records, answers) - ] + ) + + return outputs - def _generate_from_single_record( + def generate_from_single_record( self, - record: dict[str, Any], - original_record: ChatDatasetRecord, dataset_name: str, + record: dict[str, Any], + original_record: ChatDatasetRecord | None = None, ) -> ChatInferenceOutput: input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device) attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self.device) @@ -88,52 +96,41 @@ def _generate_from_single_record( pad_token_id=self._tokenizer.pad_token_id, ) - postprocessed_output_indices = 
self._postprocess( - input_indices=input_ids, - output_indices=output_indices, - remove_prompt=self._custom_generation_settings.remove_prompt, - ) - - answers = self._decode(token_indices=postprocessed_output_indices) - logits: torch.Tensor | None = None - input_token_ids: torch.Tensor | None = None - answer_tokens_ids: torch.Tensor | None = None - if self._return_logits: + if self._custom_generation_settings.return_logits: with torch.no_grad(): logits = self._model(output_indices).logits.cpu() - answer_tokens_ids = postprocessed_output_indices - input_token_ids = input_ids + postprocessed_output_indices, answers_attention_mask, postprocessed_logits, answers = self._postprocess( + input_indices=input_ids, + output_indices=output_indices, + logits=logits, + remove_prompt=self._custom_generation_settings.remove_prompt, + only_answer_logits=self._custom_generation_settings.only_answer_logits, + ) - answer_messages = [ - AnswerMessage( - id=str(i), - content=a, - input_token_ids=input_token_ids, - answer_token_ids=a_t_ids.unsqueeze(0), - logits=l.unsqueeze(0), - ) - for i, (a, a_t_ids, l) in enumerate(zip(answers, answer_tokens_ids, logits)) # type: ignore[arg-type] - ] - else: - answer_messages = [ + answer_messages = [] + for i, (answer, answer_tokens, answer_attention_mask) in enumerate( + zip(answers, postprocessed_output_indices, answers_attention_mask) + ): + answer_logits = None if postprocessed_logits is None else postprocessed_logits[i, :, :].unsqueeze(0) + answer_messages.append( AnswerMessage( id=str(i), - content=a, - input_token_ids=input_token_ids, - answer_token_ids=answer_tokens_ids, - logits=logits, + content=answer, + answer_token_ids=answer_tokens.unsqueeze(0), + answer_attention_mask=answer_attention_mask, + logits=answer_logits, ) - for i, a in enumerate(answers) - ] + ) return ChatInferenceOutput( - id=original_record.id, dataset_name=dataset_name, messages=original_record.messages, label=original_record.label, meta=original_record.meta, answers=answer_messages, + input_token_ids=input_ids, + input_attention_mask=attention_mask, ) diff --git a/turbo_alignment/generators/classification.py b/turbo_alignment/generators/classification.py index e809e16..1bca7a9 100755 --- a/turbo_alignment/generators/classification.py +++ b/turbo_alignment/generators/classification.py @@ -16,8 +16,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs): self._collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True) - def _generate_from_batch( - self, records: list[dict[str, Any]], original_records: list[ClassificationDatasetRecord], dataset_name: str + def generate_from_batch( + self, + dataset_name: str, + records: list[dict[str, Any]], + original_records: list[ClassificationDatasetRecord] | None = None, ) -> list[ClassificationInferenceOutput]: inputs = [{'input_ids': record['input_ids']} for record in records] diff --git a/turbo_alignment/generators/multimodal.py b/turbo_alignment/generators/multimodal.py index 0c64137..731b3f9 100755 --- a/turbo_alignment/generators/multimodal.py +++ b/turbo_alignment/generators/multimodal.py @@ -1,5 +1,5 @@ import torch -from transformers import PreTrainedTokenizerBase +from transformers import BatchEncoding, PreTrainedTokenizerBase from turbo_alignment.dataset.multimodal.models import MultimodalDatasetRecord from turbo_alignment.generators.base import ChatGeneratorBase @@ -26,19 +26,19 @@ def __init__( **kwargs, ) - def _generate_from_batch_records( + def generate_from_batch_records( self, - records: list[dict[str, 
torch.Tensor]], - original_records: list[MultimodalDatasetRecord], dataset_name: str, + records_batch: dict[str, torch.Tensor] | BatchEncoding, + original_records: list[MultimodalDatasetRecord] | None = None, ) -> list[MultimodalInferenceOutput]: raise ValueError('You can not use batch generation with multimodаl generator') - def _generate_from_single_record( + def generate_from_single_record( self, - record: dict[str, torch.Tensor], - original_record: MultimodalDatasetRecord, dataset_name: str, + record: dict[str, torch.Tensor], + original_record: MultimodalDatasetRecord | None = None, ) -> MultimodalInferenceOutput: input_ids = torch.unsqueeze(record['input_ids'], 0).to(self._model.language_model.device) attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self._model.language_model.device) @@ -60,7 +60,7 @@ def _generate_from_single_record( input_token_ids: torch.Tensor | None = None answer_tokens_ids: torch.Tensor | None = None - if self._return_logits: + if self._custom_generation_settings.return_logits: with torch.no_grad(): logits = self._model(output_indices).logits @@ -72,11 +72,11 @@ def _generate_from_single_record( dataset_name=dataset_name, messages=original_record.messages, label=original_record.label, + input_token_ids=input_token_ids, answers=[ AnswerMessage( id=str(i), content=a, - input_token_ids=input_token_ids, answer_token_ids=answer_tokens_ids, logits=logits, ) diff --git a/turbo_alignment/generators/rag.py b/turbo_alignment/generators/rag.py index 4203de8..25dd42e 100755 --- a/turbo_alignment/generators/rag.py +++ b/turbo_alignment/generators/rag.py @@ -1,28 +1,35 @@ -from typing import Any - import torch +from transformers import BatchEncoding from turbo_alignment.dataset.chat.models import ChatDatasetRecord -from turbo_alignment.generators.chat import ChatGenerator +from turbo_alignment.generators.chat import ChatGeneratorBase from turbo_alignment.settings.generators.outputs.chat import ( AnswerMessage, RagInferenceOutput, ) -class RagGenerator(ChatGenerator): - def _generate_from_single_record( +class RagGenerator(ChatGeneratorBase[ChatDatasetRecord, RagInferenceOutput]): + def generate_from_batch_records( + self, + dataset_name: str, + records_batch: dict[str, torch.Tensor] | BatchEncoding, + original_records: list[ChatDatasetRecord] | None = None, + ) -> list[RagInferenceOutput]: + raise ValueError('You can not use batch generation with RAG generator') + + def generate_from_single_record( self, - record: dict[str, Any], - original_record: ChatDatasetRecord, dataset_name: str, + record: dict[str, torch.Tensor], + original_record: ChatDatasetRecord | None = None, ) -> RagInferenceOutput: input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device) answer_indices, document_indices, doc_scores = self._model.generate( inputs=input_ids, generation_config=self._transformers_generator_parameters, - tokenizer=self._tokenizer.current_tokenizer, + tokenizer=self._tokenizer, pad_token_id=self._tokenizer.pad_token_id, ) diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py index ce9a9c8..b9d3892 100755 --- a/turbo_alignment/generators/rm.py +++ b/turbo_alignment/generators/rm.py @@ -1,95 +1,69 @@ from typing import Any import torch -from torch import nn -from transformers import DataCollatorWithPadding, PreTrainedTokenizerBase +from transformers import BatchEncoding, DataCollatorWithPadding, PreTrainedTokenizerBase -from turbo_alignment.dataset.pair_preferences import PairPreferenceRecord from 
turbo_alignment.dataset.sampling.models import SamplingDatasetRecord from turbo_alignment.generators.base import BaseGenerator -from turbo_alignment.settings.generators.outputs.rm import ( - RMPairInferenceOutput, - RMSamplingInferenceOutput, -) - - -class RMPairGenerator(BaseGenerator[PairPreferenceRecord, RMPairInferenceOutput]): - def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs): - self._collator = DataCollatorWithPadding(tokenizer=tokenizer) - - super().__init__(tokenizer=tokenizer, **kwargs) - - def _generate_from_batch( - self, records: list[dict[str, Any]], original_records: list[PairPreferenceRecord], dataset_name: str - ) -> list[RMPairInferenceOutput]: - merged_inputs = [r['inputs_w'] for r in records] + [r['inputs_l'] for r in records] - batch = self._collator(merged_inputs) - input_ids = batch['input_ids'].to(self.device) - attn_mask = batch['attention_mask'].to(self.device) - - with torch.no_grad(): - rewards = self._model(input_ids=input_ids, attention_mask=attn_mask).logits.cpu() - rewards_w, rewards_l = rewards[: len(records)], rewards[len(records) :] - - return [ - RMPairInferenceOutput( - id=record.id, - context=record.context, - answer_w=record.answer_w, - answer_l=record.answer_l, - reward_w=reward_w.item(), - reward_l=reward_l.item(), - dataset_name=dataset_name, - ) - for record, reward_w, reward_l in zip(original_records, rewards_w, rewards_l) - ] +from turbo_alignment.settings.generators.outputs.rm import RMSamplingInferenceOutput class RMSamplingGenerator(BaseGenerator[SamplingDatasetRecord, RMSamplingInferenceOutput]): - def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int, **kwargs): + def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, **kwargs): self._collator = DataCollatorWithPadding(tokenizer=tokenizer) self._micro_batch = micro_batch super().__init__(tokenizer=tokenizer, **kwargs) - def _generate_from_batch( - self, records: list[dict[str, Any]], original_records: list[SamplingDatasetRecord], dataset_name: str + def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEncoding) -> torch.Tensor: + with torch.no_grad(): + rewards = self._model(**records).logits.cpu() + return rewards # .squeeze() + + def generate_from_batch( + self, + dataset_name: str, + records: list[dict[str, Any]], + original_records: list[SamplingDatasetRecord] | None = None, ) -> list[RMSamplingInferenceOutput]: + self._tokenizer.padding_side = 'left' + self._tokenizer.pad_token_id = self._tokenizer.pad_token_id + merged_inputs = [inputs for record in records for key, inputs in record['answers'].items()] - if len(merged_inputs) == 0: - return [] + input_ids = [record['input_ids'].tolist() for record in merged_inputs] + attention_mask = [record['attention_mask'].tolist() for record in merged_inputs] rewards = [] - with torch.no_grad(): - input_ids = nn.utils.rnn.pad_sequence( - [item['input_ids'] for item in merged_inputs], - padding_value=self._tokenizer.pad_token_id, - batch_first=True, + for i in range(0, len(input_ids), self._micro_batch): + input_ids_batch = input_ids[i : i + self._micro_batch] + attn_mask_batch = attention_mask[i : i + self._micro_batch] + + max_input_length = max(len(sample) for sample in input_ids_batch) + + records_batch = self._tokenizer.pad( + { + 'input_ids': input_ids_batch, + 'attention_mask': attn_mask_batch, + }, + padding='max_length', + max_length=max_input_length, + return_tensors='pt', + ).to(self.device) + + 
rewards.extend(self.generate_from_batch_records(records_batch).tolist()) + + outputs = [] + for i, record in enumerate(original_records): + record_rewards = {answer.id: rewards[i + j] for j, answer in enumerate(record.answers)} + + outputs.append( + RMSamplingInferenceOutput( + id=record.id, + answers=record.answers, + dataset_name=record.dataset_name, + messages=record.messages, + rewards=record_rewards, + ) ) - attn_mask = nn.utils.rnn.pad_sequence( - [item['attention_mask'] for item in merged_inputs], - padding_value=0, - batch_first=True, - ) - for i in range(0, len(input_ids), self._micro_batch): - input_ids_batch = input_ids[i : i + self._micro_batch].to(self.device) - attn_mask_batch = attn_mask[i : i + self._micro_batch].to(self.device) - rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits.cpu()) - - rewards = torch.cat(rewards, dim=0) - record_rewards = [ - {answer_id: reward.item() for answer_id, reward in zip(record['answers'].keys(), rewards)} - for record in records - ] - - return [ - RMSamplingInferenceOutput( - id=record.id, - rewards=rewards, - messages=record.messages, - dataset_name=dataset_name, - answers=record.answers, - ) - for record, rewards in zip(original_records, record_rewards) - ] + return outputs diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py index ff52de3..03dfd22 100755 --- a/turbo_alignment/generators/vllm_chat.py +++ b/turbo_alignment/generators/vllm_chat.py @@ -2,7 +2,11 @@ import torch from transformers import PreTrainedTokenizerBase -from vllm import LLM, SamplingParams + +try: + from vllm import LLM, SamplingParams +except ImportError as exc: + raise exc from turbo_alignment.dataset.chat import ChatDatasetRecord from turbo_alignment.generators.base import BaseGenerator @@ -22,7 +26,6 @@ def __init__( model: LLM, tokenizer: PreTrainedTokenizerBase, batch: int, - return_logits: bool = False, ): model.set_tokenizer(tokenizer) super().__init__(model, tokenizer, batch=batch) @@ -52,10 +55,13 @@ def __init__( **beam_search_params, ) - self._return_logits = return_logits + self._custom_generation_settings = custom_generation_settings - def _generate_from_batch( - self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str + def generate_from_batch( + self, + dataset_name: str, + records: list[dict[str, Any]], + original_records: list[ChatDatasetRecord] | None = None, ) -> list[ChatInferenceOutput]: input_ids = [record['input_ids'].tolist() for record in records] request_outputs = self._model.generate( @@ -74,7 +80,7 @@ def _generate_from_batch( content=a.text, sequence_score=a.cumulative_logprob, ) - if self._return_logits: + if self._custom_generation_settings.return_logits: ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0) ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0) diff --git a/turbo_alignment/pipelines/inference/chat.py b/turbo_alignment/pipelines/inference/chat.py index a40df0f..0a3a051 100755 --- a/turbo_alignment/pipelines/inference/chat.py +++ b/turbo_alignment/pipelines/inference/chat.py @@ -26,7 +26,10 @@ def _get_single_inference_settings( ) if model_inference_settings.use_vllm: - import vllm + try: + import vllm + except ImportError as exc: + raise exc from turbo_alignment.generators.vllm_chat import VLLMChatGenerator diff --git a/turbo_alignment/pipelines/train/__init__.py b/turbo_alignment/pipelines/train/__init__.py index ac24318..9511009 100755 --- 
a/turbo_alignment/pipelines/train/__init__.py +++ b/turbo_alignment/pipelines/train/__init__.py @@ -4,5 +4,6 @@ from .kto import TrainKTOStrategy from .multimodal import TrainMultimodalStrategy from .rag import TrainRAGStrategy +from .reinforce import TrainREINFORCEStrategy from .rm import TrainRMStrategy from .sft import TrainSFTStrategy diff --git a/turbo_alignment/pipelines/train/reinforce.py b/turbo_alignment/pipelines/train/reinforce.py new file mode 100644 index 0000000..ef04fe3 --- /dev/null +++ b/turbo_alignment/pipelines/train/reinforce.py @@ -0,0 +1,132 @@ +from typing import Callable + +from torch.utils.data import ConcatDataset, Dataset +from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.data.data_collator import ( + DataCollatorForTokenClassification, + DataCollatorMixin, +) + +from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback +from turbo_alignment.common.logging import get_project_logger +from turbo_alignment.common.tf.loaders import load_model +from turbo_alignment.constants import TRAINER_LOGS_FOLDER +from turbo_alignment.dataset.chat.chat import InferenceChatDataset +from turbo_alignment.dataset.loader import DatasetLoader +from turbo_alignment.metrics.metric import Metric +from turbo_alignment.metrics.registry import MetricSettingsRegistry +from turbo_alignment.pipelines.train.base import BaseTrainStrategy +from turbo_alignment.settings.datasets.base import DatasetStrategy +from turbo_alignment.settings.pipelines.train.reinforce import ( + REINFORCETrainExperimentSettings, +) +from turbo_alignment.trainers.online.reinforce import ( + REINFORCETrainer, + REINFORCETrainingArguments, +) + +logger = get_project_logger() + + +class TrainREINFORCEStrategy(BaseTrainStrategy[REINFORCETrainExperimentSettings]): + @staticmethod + def _get_data_collator( + experiment_settings: REINFORCETrainExperimentSettings, + tokenizer: PreTrainedTokenizerBase, + **kwargs, + ) -> Callable: + return DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8) + + @staticmethod + def _get_cherry_pick_callback( + experiment_settings: REINFORCETrainExperimentSettings, + tokenizer: PreTrainedTokenizerBase, + **kwargs, + ) -> ChatCherryPickCallback: + cherry_pick_settings = experiment_settings.cherry_pick_settings + + cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( + cherry_pick_settings.dataset_settings, + tokenizer=tokenizer, + strategy=DatasetStrategy.INFERENCE, + ) + + metrics = [ + Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters)) + for metric in cherry_pick_settings.metric_settings + ] + + return ChatCherryPickCallback( + cherry_pick_settings=cherry_pick_settings, + datasets=cherry_pick_datasets, + metrics=metrics, + ) + + @staticmethod + def _get_training_args( + experiment_settings: REINFORCETrainExperimentSettings, + ) -> REINFORCETrainingArguments: + return
REINFORCETrainingArguments( + output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER), + label_names=[], + remove_unused_columns=False, + **experiment_settings.trainer_settings.dict(), + ) + + @staticmethod + def _get_trainer( + training_args: REINFORCETrainingArguments, + experiment_settings: REINFORCETrainExperimentSettings, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + train_dataset: Dataset, + val_dataset: Dataset, + data_collator: Callable, + ): + # TODO: different tokenizer for reward model + + ref_model = load_model(experiment_settings.model_settings, tokenizer) + for _, param in ref_model.named_parameters(): + param.requires_grad = False + + ref_model.eval() + + reward_model = load_model(experiment_settings.reward_model_settings, tokenizer) + for _, param in reward_model.named_parameters(): + param.requires_grad = False + + reward_model.eval() + + return REINFORCETrainer( + args=training_args, + tokenizer=tokenizer, + policy=model, + ref_model=ref_model, + train_dataset=train_dataset, + eval_dataset=val_dataset, + data_collator=data_collator, + reward_model=reward_model, + callbacks=[], + ) + + def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None: + logger.info(f'Train sample input_ids:\n{dataset[0]}') + logger.info(f'Train sample example:\n{self.tokenizer.decode(dataset[0]["input_ids"])}') + + def _get_datasets(self, experiment_settings: REINFORCETrainExperimentSettings) -> tuple[Dataset, Dataset]: + train_dataset: ConcatDataset = ConcatDataset( + datasets=DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( + experiment_settings.train_dataset_settings, + tokenizer=self.tokenizer, + strategy=DatasetStrategy.INFERENCE, + ) + ) + + val_dataset: ConcatDataset = ConcatDataset( + datasets=DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets( + experiment_settings.val_dataset_settings, + tokenizer=self.tokenizer, + strategy=DatasetStrategy.INFERENCE, + ) + ) + return train_dataset, val_dataset diff --git a/turbo_alignment/settings/generators/chat.py b/turbo_alignment/settings/generators/chat.py index 01bcc09..b672bec 100755 --- a/turbo_alignment/settings/generators/chat.py +++ b/turbo_alignment/settings/generators/chat.py @@ -4,4 +4,6 @@ class CustomChatGenerationSettings(ExtraFieldsNotAllowedBaseModel): skip_special_tokens: bool = True remove_prompt: bool = True + only_answer_logits: bool = True batch: int = 1 + return_logits: bool = False diff --git a/turbo_alignment/settings/generators/outputs/chat.py b/turbo_alignment/settings/generators/outputs/chat.py index 3d4c392..78719a1 100755 --- a/turbo_alignment/settings/generators/outputs/chat.py +++ b/turbo_alignment/settings/generators/outputs/chat.py @@ -1,6 +1,8 @@ +from typing import Any + import torch -from turbo_alignment.dataset.chat import ChatDatasetRecord +from turbo_alignment.dataset.chat import ChatMessage from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.generators.outputs.base import BaseInferenceOutput @@ -8,17 +10,26 @@ class AnswerMessage(ExtraFieldsNotAllowedBaseModel): id: str content: str + answer_token_ids: torch.Tensor + answer_attention_mask: torch.Tensor sequence_score: float | None = None - input_token_ids: torch.Tensor | None = None - answer_token_ids: torch.Tensor | None = None logits: torch.Tensor | None = None class Config: arbitrary_types_allowed = True -class ChatInferenceOutput(BaseInferenceOutput, ChatDatasetRecord): +class 
ChatInferenceOutput(BaseInferenceOutput): + id: str | None = None + input_token_ids: torch.Tensor | None = None + input_attention_mask: torch.Tensor | None = None answers: list[AnswerMessage] + messages: list[ChatMessage] | None + label: str | None = None + meta: dict[str, Any] | None = None + + class Config: + arbitrary_types_allowed = True class RagInferenceOutput(ChatInferenceOutput): diff --git a/turbo_alignment/settings/generators/outputs/multimodal.py b/turbo_alignment/settings/generators/outputs/multimodal.py index 6238397..69c21c3 100755 --- a/turbo_alignment/settings/generators/outputs/multimodal.py +++ b/turbo_alignment/settings/generators/outputs/multimodal.py @@ -1,5 +1,6 @@ from typing import Annotated +import torch from pydantic import Field from turbo_alignment.dataset.base.models import DatasetRecord @@ -18,3 +19,7 @@ class MultimodalInferenceOutput(BaseInferenceOutput, DatasetRecord): ] label: str | None = None answers: list[AnswerMessage] + input_token_ids: torch.Tensor | None = None + + class Config: + arbitrary_types_allowed = True diff --git a/turbo_alignment/settings/online.py b/turbo_alignment/settings/online.py new file mode 100644 index 0000000..52d14c7 --- /dev/null +++ b/turbo_alignment/settings/online.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class LLMActorType(str, Enum): + LOCAL_TRANSFORMERS = 'local_transformers' + DISTRIBUTED_VLLM = 'distributed_vllm' + + +class CriticType(str, Enum): + LOCAL_TRANSFORMERS = 'local_transformers' + DISTRIBUTED_VLLM = 'distributed_vllm' + + +class RewardProcessorType(str, Enum): + RLOO = 'rloo' diff --git a/turbo_alignment/settings/pipelines/__init__.py b/turbo_alignment/settings/pipelines/__init__.py index 1b6143a..0a4d7c0 100755 --- a/turbo_alignment/settings/pipelines/__init__.py +++ b/turbo_alignment/settings/pipelines/__init__.py @@ -14,5 +14,8 @@ MultimodalTrainExperimentSettings, ) from turbo_alignment.settings.pipelines.train.rag import RAGTrainExperimentSettings +from turbo_alignment.settings.pipelines.train.reinforce import ( + REINFORCETrainExperimentSettings, +) from turbo_alignment.settings.pipelines.train.rm import RMTrainExperimentSettings from turbo_alignment.settings.pipelines.train.sft import SftTrainExperimentSettings diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py new file mode 100644 index 0000000..be1c71a --- /dev/null +++ b/turbo_alignment/settings/pipelines/train/reinforce.py @@ -0,0 +1,37 @@ +from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings +from turbo_alignment.settings.datasets import ChatMultiDatasetSettings +from turbo_alignment.settings.model import ( + PreTrainedAdaptersModelSettings, + PreTrainedModelSettings, +) +from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings +from turbo_alignment.settings.tf.trainer import TrainerSettings + + +class REINFORCETrainerSettings(TrainerSettings): + max_tokens_count: int = 1024 + stop_token: str = '' + + penalty_reward_value: float = 0.1 + clip_rewards_min: float = 0.1 + clip_rewards_max: float = 1.0 + kl_coef: float = 0.05 + mean_baseline_coef: float = 0.1 + + num_generations: int = 3 + num_samples_for_reward_stats: int = 0 + + non_eos_penalty: bool = True + temperature: float | None = None + whiten_rewards: bool = False + + +class REINFORCETrainExperimentSettings(BaseTrainExperimentSettings): + train_dataset_settings: ChatMultiDatasetSettings + val_dataset_settings: ChatMultiDatasetSettings + + cherry_pick_settings: 
ChatCherryPickSettings + + reward_model_settings: PreTrainedModelSettings | PreTrainedAdaptersModelSettings + + trainer_settings: REINFORCETrainerSettings diff --git a/turbo_alignment/settings/tf/generation.py b/turbo_alignment/settings/tf/generation.py index 14537d9..f9a5eae 100755 --- a/turbo_alignment/settings/tf/generation.py +++ b/turbo_alignment/settings/tf/generation.py @@ -3,7 +3,7 @@ class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel): num_beams: int = 1 - max_new_tokens: int = 15 + max_new_tokens: int | None = None repetition_penalty: float = 1.0 num_return_sequences: int = 1 do_sample: bool = True @@ -11,3 +11,4 @@ class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel): top_k: int = 50 temperature: float = 1.0 stop_strings: str | list[str] = '' + max_length: int = 20 diff --git a/turbo_alignment/trainers/online/ray/vllm_engine.py b/turbo_alignment/trainers/online/ray/vllm_engine.py new file mode 100644 index 0000000..b851ae1 --- /dev/null +++ b/turbo_alignment/trainers/online/ray/vllm_engine.py @@ -0,0 +1,118 @@ +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from turbo_alignment.common.logging import get_project_logger + +logger = get_project_logger() + + +@ray.remote +class LLMRayActor: + def __init__(self, *args, **kwargs): + import vllm + + self.__version__ = vllm.__version__ + assert self.__version__ >= '0.4.1', 'OpenRLHF only supports vLLM >= 0.4.1' + + self.use_gpu_executor = kwargs['tensor_parallel_size'] == 1 + + # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + if self.use_gpu_executor: + from turbo_alignment.trainers.online.ray.vllm_worker_wrap import WorkerWrap + + vllm.worker.worker.Worker = WorkerWrap + else: + # RayGPUExecutor + # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 + kwargs['worker_use_ray'] = True + + if vllm.__version__ > '0.4.1': + RayWorkerWrapperPath = vllm.executor.ray_utils + else: + RayWorkerWrapperPath = vllm.engine.ray_utils + + class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper): + def __init__(self, *args, **kwargs) -> None: + kwargs['worker_module_name'] = 'openrlhf.trainer.ray.vllm_worker_wrap' + kwargs['worker_class_name'] = 'WorkerWrap' + super().__init__(*args, **kwargs) + + RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper + + self.llm = vllm.LLM(*args, **kwargs) + + def generate(self, *args, **kwargs): + return self.llm.generate(*args, **kwargs) + + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): + if self.use_gpu_executor: + return self.llm.llm_engine.model_executor.driver_worker.init_process_group( + master_address, master_port, rank_offset, world_size, group_name, backend + ) + else: + return self.llm.llm_engine.model_executor._run_workers( + 'init_process_group', master_address, master_port, rank_offset, world_size, group_name, backend + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + self.stop_remote_worker_execution_loop() + + if self.use_gpu_executor: + return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache) + else: + return self.llm.llm_engine.model_executor._run_workers('update_weight', name, dtype, shape, empty_cache) + + def stop_remote_worker_execution_loop(self): + # Fix error for using 2 communication group + # 
https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4 + if self.__version__ > '0.4.2': + self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop() + + +def create_vllm_engines( + num_engines: int, + tensor_parallel_size: int, + pretrain: str, + seed: int, + enable_prefix_caching: bool, + max_model_len: int, +): + vllm_engines = [] + for i in range(num_engines): + # When tensor_parallel_size=1, vLLM init model in LLMEngine directly, assign 1 GPU for it. + num_gpus = int(tensor_parallel_size == 1) + scheduling_strategy = None + + if tensor_parallel_size > 1: + bundles = [{'GPU': 1, 'CPU': 1}] * tensor_parallel_size + pg = placement_group(bundles) + ray.get(pg.ready()) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0 + ) + + vllm_engines.append( + LLMRayActor.options( + num_cpus=1, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote( + pretrain, + trust_remote_code=True, + tensor_parallel_size=tensor_parallel_size, + dtype='bfloat16', + seed=seed + i, + enable_prefix_caching=enable_prefix_caching, + max_model_len=max_model_len, + ) + ) + + return vllm_engines + + +if __name__ == '__main__': + llm = LLMRayActor.remote('meta-llama/Llama-2-7b-chat-hf', tensor_parallel_size=4) + output = ray.get(llm.generate.remote('San Franciso is a')) + print(f'output: {output}') diff --git a/turbo_alignment/trainers/online/ray/vllm_worker_wrap.py b/turbo_alignment/trainers/online/ray/vllm_worker_wrap.py new file mode 100644 index 0000000..c40a39d --- /dev/null +++ b/turbo_alignment/trainers/online/ray/vllm_worker_wrap.py @@ -0,0 +1,43 @@ +import torch +from vllm.worker.worker import Worker + +from turbo_alignment.common.distributed import init_process_group +from turbo_alignment.common.logging import get_project_logger + +logger = get_project_logger() + + +class WorkerWrap(Worker): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend='nccl'): + """Init torch process group for model weights update""" + assert torch.distributed.is_initialized(), 'default torch process group must be initialized' + assert group_name != '', 'group name must not be empty' + + rank = torch.distributed.get_rank() + rank_offset + self._model_update_group = init_process_group( + backend=backend, + init_method=f'tcp://{master_address}:{master_port}', + world_size=world_size, + rank=rank, + group_name=group_name, + ) + print( + f'init_process_group: master_address={master_address}, master_port={master_port}, ', + f'rank={rank}, world_size={world_size}, group_name={group_name}', + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + """Broadcast weight to all vllm workers from source rank 0 (actor model)""" + if torch.distributed.get_rank() == 0: + print(f'update weight: {name}, dtype: {dtype}, shape: {shape}') + + assert dtype == self.model_config.dtype, f'mismatch dtype: src {dtype}, dst {self.model_config.dtype}' + weight = torch.empty(shape, dtype=dtype, device='cuda') + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + # TODO: should we empty cache if all weights have updated? 
+ # if empty_cache: + # torch.cuda.empty_cache() diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py new file mode 100644 index 0000000..85c8e20 --- /dev/null +++ b/turbo_alignment/trainers/online/reinforce.py @@ -0,0 +1,388 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils +import torch.utils.data +from datasets import Dataset +from transformers import ( + PreTrainedModel, + PreTrainedTokenizerBase, + TrainerCallback, + TrainingArguments, GenerationConfig, +) + +from turbo_alignment.common.distributed import ( + get_global_mean, + get_global_std, + get_log_mean_std, +) +from turbo_alignment.generators import ChatGenerator, RMSamplingGenerator +from turbo_alignment.generators.base import ChatGeneratorBase +from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings +from turbo_alignment.settings.generators.outputs.chat import ChatInferenceOutput +from turbo_alignment.settings.online import ( + CriticType, + LLMActorType, + RewardProcessorType, +) +from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings +from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer +from turbo_alignment.trainers.online.reward_processor import RewardProcessor +from turbo_alignment.trainers.utils import prepare_model + + +# FIXME +@dataclass +class REINFORCETrainingArguments(TrainingArguments): + max_tokens_count: int = 1024 + stop_token: str = '' + + penalty_reward_value: float = 0.1 + clip_rewards_min: float = 0.1 + clip_rewards_max: float = 1.0 + kl_coef: float = 0.05 + mean_baseline_coef: float = 0.1 + + num_generations: int = 3 + num_samples_for_reward_stats: int = 0 + + non_eos_penalty: bool = True + temperature: float | None = None + whiten_rewards: bool = False + + actor_type: LLMActorType = LLMActorType.LOCAL_TRANSFORMERS + critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS + + reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO + + +class REINFORCETrainer(MultiGPUCherryPicksTrainer): + def __init__( + self, + args: REINFORCETrainingArguments, + tokenizer: PreTrainedTokenizerBase, + policy: PreTrainedModel | torch.nn.Module, + reward_model: PreTrainedModel | torch.nn.Module, + train_dataset: Dataset, + eval_dataset: Dataset, + ref_model: PreTrainedModel | torch.nn.Module | None = None, + callbacks: list[TrainerCallback] | None = None, + **kwargs, + ) -> None: + super().__init__( + model=policy, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + callbacks=callbacks, + **kwargs, + ) + + self.ref_model = ref_model + if self.ref_model is not None: + self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled) + + self.reward_model = reward_model + + self._stored_metrics: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list)) + self.kl_coef = args.kl_coef + + self.gen_id = 1 + + # Take gradient accumulation into account + # In this way, the effective number of gradient + # updates taken into account stays the same + effective_num_previous_samples = 1 / (1 - self.args.mean_baseline_coef) + self.args.mean_baseline_coef = 1 - 1 / (effective_num_previous_samples * self.args.gradient_accumulation_steps) + + self.stop_generation_token_id = tokenizer.encode(args.stop_token, add_special_tokens=False) + assert 
len(self.stop_generation_token_id) == 1, self.stop_generation_token_id + + self.generator_transformers_settings = GeneratorTransformersSettings( + temperature=args.temperature, + max_length=args.max_tokens_count, + num_return_sequences=args.num_generations, + num_beams=args.num_generations, + do_sample=True, + ) + self.generator_custom_settings = CustomChatGenerationSettings( + batch=self.args.per_device_train_batch_size, + remove_prompt=False, + only_answer_logits=True, + return_logits=True, + ) + self.reward_processor = RewardProcessor.by_name(args.reward_processor_type)( + mean_baseline_coef=args.mean_baseline_coef, + num_generations=args.num_generations, + ) + + self.norm_reward_mean, self.norm_reward_std = self.reward_stats( + model=self.model, dataloader=self.get_train_dataloader() + ) + + # FIXME: some base class instead of RMSamplingGenerator (join with ChatGeneratorBase?) + def _get_rm_generator(self, reward_model: torch.nn.Module | PreTrainedModel) -> RMSamplingGenerator: + match self.args.critic_type: + case CriticType.LOCAL_TRANSFORMERS: + generator = RMSamplingGenerator( + model=reward_model, + tokenizer=self.tokenizer, + accelerator=self.accelerator, + ) + case CriticType.DISTRIBUTED_VLLM: + generator = ... + case _: + raise ValueError(f'Critic {self.args.critic_type} is not supported') + return generator + + def _get_chat_generator(self, model: torch.nn.Module | PreTrainedModel) -> ChatGeneratorBase: + match self.args.actor_type: + case LLMActorType.LOCAL_TRANSFORMERS: + generator = ChatGenerator( + model=model, + tokenizer=self.tokenizer, + transformers_settings=self.generator_transformers_settings, + custom_generation_settings=self.generator_custom_settings, + accelerator=self.accelerator, + ) + case LLMActorType.DISTRIBUTED_VLLM: + generator = ... 
+ case _: + raise ValueError(f'LLM Actor {self.args.actor_type} is not supported') + return generator + + def get_answers_and_rewards( + self, + model: torch.nn.Module | PreTrainedModel, + inputs: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + generator = self._get_chat_generator(model) + + generations: list[ChatInferenceOutput] = generator.generate_from_batch_records( + dataset_name='online', records_batch=inputs + ) + + response_ids = torch.stack([ans.answer_token_ids for g in generations for ans in g.answers]) + response_attention_mask = torch.stack([ans.answer_attention_mask for g in generations for ans in g.answers]) + + response_tokens_mask = torch.zeros(response_ids.shape, dtype=torch.float32) + response_tokens_mask[:, generations[0].input_token_ids.shape[0] :] = 1.0 + + position_ids = (response_attention_mask.cumsum(-1) - 1).clamp(min=0) + position_ids.masked_fill_(response_attention_mask.to(torch.bool) == 0, 0) + + rm_inputs = { + 'input_ids': response_ids, + 'attention_mask': response_attention_mask, + 'position_ids': position_ids, + } + + critic_generator = self._get_rm_generator(self.reward_model) + rewards = critic_generator.generate_from_batch_records(rm_inputs) + + return response_ids, response_attention_mask, response_tokens_mask, position_ids, rewards + + def get_batch_loss_metrics( + self, + model: torch.nn.Module | PreTrainedModel, + inputs: dict[str, torch.Tensor], + train_eval: Literal['train', 'eval'] = 'train', + ): + with torch.no_grad(): + ( + query_response, + attention_mask, + response_tokens_mask, + position_ids, + rewards, + ) = self.get_answers_and_rewards( + model=self.accelerator.unwrap_model(model), + inputs=inputs, + ) + + ref_logprobs = self.get_logprobs( + model=self.ref_model, + input_ids=query_response, + attention_mask=attention_mask, + position_ids=position_ids, + loss_mask=response_tokens_mask.to(torch.float32), + ) + + rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response) + + logprobs = self.get_logprobs( + model=model, + input_ids=query_response, + attention_mask=attention_mask, + position_ids=position_ids, + loss_mask=response_tokens_mask, + ) + + with torch.no_grad(): + kl_term = logprobs.detach() - ref_logprobs + regularized_rewards = rewards - self.kl_coef * kl_term + + baselined_reward, baseline_metrics = self.reward_processor.baseline_rewards(rewards=regularized_rewards) + + loss = -baselined_reward * logprobs + + assert len(loss.shape) == 1, loss.shape + + with torch.no_grad(): + metrics = {} + for tensor, name in zip( + [ + rewards, + regularized_rewards, + baselined_reward, + loss.detach(), + kl_term.detach(), + logprobs.detach(), + 1 - valid_mask.float(), + ], + [ + 'rewards', + 'regularized_rewards', + 'baselined_rewards', + 'loss', + 'kl_term', + 'logprobs', + 'invalid_rewards', + ], + ): + metrics = { + **metrics, + **get_log_mean_std(tensor, name, train_eval), + } + metrics[f'{train_eval}/kl_coef'] = self.kl_coef + + for k, v in rewards_metrics.items(): + metrics[f'{train_eval}/{k}'] = v + for k, v in baseline_metrics.items(): + metrics[f'{train_eval}/{k}'] = v + + return loss.mean(), metrics + + def compute_loss(self, model, inputs, return_outputs: bool = False): + loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train') + + self.store_metrics(metrics=metrics, train_eval='train') + + return (loss, metrics) if return_outputs else loss + + def prediction_step( + self, + model: Union[PreTrainedModel, 
torch.nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + with torch.no_grad(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, 'eval') + + self.store_metrics(metrics, train_eval='eval') + + return loss.detach(), None, None + + @torch.no_grad() + def reward_stats( + self, model: torch.nn.Module | PreTrainedModel, dataloader: torch.utils.data.DataLoader + ) -> tuple[float, float]: + if self.args.num_samples_for_reward_stats == 0: + norm_reward_mean = 0 + norm_reward_std = 1 + else: + all_rewards = [] + samples_processed = 0 + for batch in dataloader: + _, _, _, _, rewards = self.get_answers_and_rewards( + model=self.accelerator.unwrap_model(model), + inputs=batch, + ) + rewards, _ = self.reward_processor.postprocess_rewards(rewards=rewards) + + all_rewards.append(rewards) + samples_processed += len(batch['input_ids']) * self.accelerator.num_processes + + if samples_processed > self.args.num_samples_for_reward_stats: + break + + all_rewards = torch.cat(all_rewards, 0) + + norm_reward_mean = get_global_mean(all_rewards) + norm_reward_std = get_global_std(all_rewards, mean=norm_reward_mean) + + return norm_reward_mean, norm_reward_std + + def get_logprobs( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + loss_mask: torch.Tensor, + ): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ).logits[:, :-1] + logits /= self.args.temperature + + all_logprob = F.log_softmax(logits, dim=-1) + + logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1) + logprob[~loss_mask[:, 1:].to(torch.bool)] = 0 + + return logprob.sum(-1) + + def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]: + if self.args.non_eos_penalty: + invalid_mask = query_response[:, -1] != self.stop_generation_token_id[0] + rewards[invalid_mask] = self.args.penalty_reward_value + + return rewards, ~invalid_mask + + return rewards, torch.ones_like(rewards).to(torch.bool) + + def process_rewards(self, rewards, query_response) -> tuple[torch.Tensor, torch.Tensor, Any]: + """ + Second, this token should be the . + Otherwise, return a fixed reward of -1. 
+ """ + + rewards, reward_metrics = self.reward_processor.postprocess_rewards(rewards=rewards) + rewards = rewards.squeeze(-1) + reward_metrics['normalizing_reward_mean'] = self.norm_reward_mean + reward_metrics['normalizing_reward_std'] = self.norm_reward_std + rewards = (rewards - self.norm_reward_mean) / self.norm_reward_std + # boundness is required: https://arxiv.org/pdf/2406.01462 + rewards = torch.clamp(rewards, self.args.clip_rewards_min, self.args.clip_rewards_max) + + rewards, valid_mask = self.fill_nonvalid_rewards(rewards=rewards, query_response=query_response) + + return rewards, valid_mask, reward_metrics + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal['train', 'eval'] = 'train') -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def log(self, logs: Dict[str, float]) -> None: + main_keys = list(self._stored_metrics.keys()) + for main_key in main_keys: + for key, metrics in self._stored_metrics[main_key].items(): + logs[key] = torch.tensor(metrics).float().cpu().mean().item() + + if main_key == 'train': + logs['train/global_step'] = int(self.state.global_step) + else: + logs['eval/global_step'] = int(self.state.global_step // self.args.eval_steps) + + del self._stored_metrics[main_key] + + return super().log(logs) # pylint: disable=no-member diff --git a/turbo_alignment/trainers/online/reward_actor.py b/turbo_alignment/trainers/online/reward_actor.py new file mode 100644 index 0000000..e69de29 diff --git a/turbo_alignment/trainers/online/reward_processor.py b/turbo_alignment/trainers/online/reward_processor.py new file mode 100644 index 0000000..50d126c --- /dev/null +++ b/turbo_alignment/trainers/online/reward_processor.py @@ -0,0 +1,67 @@ +from abc import ABC + +import torch +from allenai_common import Registrable + +from turbo_alignment.common.distributed import get_global_mean, get_log_mean_std +from turbo_alignment.settings.online import RewardProcessorType + + +class RewardProcessor(ABC, Registrable): + def __init__( + self, + mean_baseline_coef: float, + num_generations: int, + ) -> None: + self.mean_reward: float | None = None + self.mean_baseline_coef = mean_baseline_coef + self.num_generations = num_generations + + def postprocess_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + raise NotImplementedError + + def baseline_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + baseline: float = self.mean_reward if self.mean_reward is not None else 0 + advantages: torch.Tensor = rewards - baseline + + with torch.no_grad(): + metrics: dict[str, float | int] = {'baseline_mean': baseline} + + self.update_mean_reward(rewards=rewards) + + return advantages, metrics + + @torch.no_grad() + def update_mean_reward(self, rewards: torch.Tensor): + global_mean_reward: float = get_global_mean(rewards) + + if self.mean_reward is None: + self.mean_reward = global_mean_reward + else: + self.mean_reward = ( + self.mean_baseline_coef * self.mean_reward + (1 - self.mean_baseline_coef) * global_mean_reward + ) + + +@RewardProcessor.register(RewardProcessorType.RLOO) +class RLOORewardProcessor(RewardProcessor): + def postprocess_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: + rewards = rewards[:, 0].unsqueeze(-1) # values are at 1 + + with torch.no_grad(): + metrics: dict[str, float] = get_log_mean_std(rewards, 'real_reward') + + return rewards, metrics + + def baseline_rewards(self, rewards: torch.Tensor) -> 
tuple[torch.Tensor, dict[str, float]]: + rewards = rewards.reshape(-1, self.num_generations) + baseline: torch.Tensor = (rewards.sum(-1).unsqueeze(-1) - rewards) / (self.num_generations - 1) + rloo_advantages: torch.Tensor = (rewards - baseline).flatten() + + with torch.no_grad(): + metrics: dict[str, float] = { + **get_log_mean_std(baseline, 'baseline'), + **get_log_mean_std(baseline.std(-1), 'baseline_inner_std'), + } + + return rloo_advantages, metrics
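Note on the RLOO processor above: baseline_rewards reshapes the flat reward vector into (num_prompts, num_generations) groups and uses, for each sampled answer, the mean reward of the other k-1 answers to the same prompt as its baseline. A standalone sketch of that arithmetic (function name and test values are illustrative, not part of the library):

import torch

def rloo_advantages(rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    grouped = rewards.reshape(-1, num_generations)                     # (num_prompts, k)
    # Each answer's baseline is the mean reward of the other k-1 answers to the same prompt.
    baseline = (grouped.sum(-1, keepdim=True) - grouped) / (num_generations - 1)
    return (grouped - baseline).flatten()

rewards = torch.tensor([1.0, 0.0, 2.0, 3.0, 3.0, 0.0])   # two prompts, three answers each
print(rloo_advantages(rewards, num_generations=3))        # [0.0, -1.5, 1.5, 1.5, 1.5, -3.0]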
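For context on process_rewards and fill_nonvalid_rewards: raw reward-model scores are standardized with the statistics collected by reward_stats, clamped to [clip_rewards_min, clip_rewards_max] (boundedness per arXiv:2406.01462), and answers that never emitted the stop token are overwritten with penalty_reward_value. A minimal sketch with hypothetical argument names, not the trainer's actual signature:

import torch

def postprocess_rewards(rewards, last_token_ids, stop_token_id,
                        mean=0.0, std=1.0, clip_min=-1e8, clip_max=1e8,
                        non_eos_penalty=True, penalty_value=-1.0):
    rewards = (rewards - mean) / std                        # normalize with pre-computed stats
    rewards = torch.clamp(rewards, clip_min, clip_max)      # keep the reward bounded
    valid = torch.ones_like(rewards, dtype=torch.bool)
    if non_eos_penalty:
        # Answers whose last token is not the stop token get a fixed penalty instead.
        valid = last_token_ids == stop_token_id
        rewards = torch.where(valid, rewards, torch.full_like(rewards, penalty_value))
    return rewards, valid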
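The loss computed in get_batch_loss_metrics is a sequence-level REINFORCE objective in which the KL penalty against the reference model is folded into the reward rather than added as a separate loss term. A simplified sketch under that reading (tensor names are illustrative; in the trainer the baseline comes from the EMA or RLOO processor):

import torch

def reinforce_loss(logprobs, ref_logprobs, rewards, kl_coef=0.05, baseline=0.0):
    # KL term per sequence; detached so the penalty shapes the reward, not the gradient path.
    kl_term = logprobs.detach() - ref_logprobs
    regularized_rewards = rewards - kl_coef * kl_term
    advantages = regularized_rewards - baseline
    return -(advantages * logprobs).mean()        # advantages carry no gradient; logprobs do

logprobs = torch.tensor([-12.0, -15.0], requires_grad=True)
loss = reinforce_loss(logprobs, torch.tensor([-13.0, -14.0]), rewards=torch.tensor([0.8, -0.2]))
loss.backward()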
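get_logprobs scores each sampled sequence under the policy and the frozen reference model by gathering per-token log-probabilities at the next-token positions, masking out prompt tokens, and summing over the answer. A sketch of the same computation on toy tensors (shapes and mask are assumptions for illustration):

import torch
import torch.nn.functional as F

def sequence_logprobs(logits, input_ids, loss_mask, temperature=1.0):
    shifted = logits[:, :-1] / temperature                        # position t predicts token t+1
    token_logprobs = F.log_softmax(shifted, dim=-1)
    picked = torch.gather(token_logprobs, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
    picked = picked * loss_mask[:, 1:]                            # keep only answer-token positions
    return picked.sum(-1)                                         # one log-probability per sequence

logits = torch.randn(2, 6, 32)                                    # (batch, seq_len, vocab)
input_ids = torch.randint(0, 32, (2, 6))
loss_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1]], dtype=torch.float)
print(sequence_logprobs(logits, input_ids, loss_mask))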
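Finally, the constructor rescales mean_baseline_coef so that the EMA baseline spans the same effective number of optimizer updates regardless of gradient accumulation; with the defaults from the experiment config the adjustment works out as follows:

def rescale_baseline_coef(mean_baseline_coef: float, gradient_accumulation_steps: int) -> float:
    # coef=0.95 means the EMA averages over roughly 1 / (1 - 0.95) = 20 previous batches;
    # with accumulation the same horizon must cover that many micro-batches per optimizer step.
    effective_num_previous_samples = 1 / (1 - mean_baseline_coef)
    return 1 - 1 / (effective_num_previous_samples * gradient_accumulation_steps)

print(rescale_baseline_coef(0.95, 10))   # 0.995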