diff --git a/configs/exp/train/dpo/dpo.json b/configs/exp/train/dpo/dpo.json
index ba6e9fe..52c8b83 100755
--- a/configs/exp/train/dpo/dpo.json
+++ b/configs/exp/train/dpo/dpo.json
@@ -100,7 +100,7 @@
"bos_token": "<|begin_of_text|>",
"eos_token": "<|end_of_text|>"
},
- "trainer_settings": {
+ "trainer_settings": {
"evaluation_strategy": "steps",
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
diff --git a/configs/exp/train/reinforce/reinforce.json b/configs/exp/train/reinforce/reinforce.json
new file mode 100644
index 0000000..1588bfc
--- /dev/null
+++ b/configs/exp/train/reinforce/reinforce.json
@@ -0,0 +1,214 @@
+{
+ "train_dataset_settings": {
+ "sources": [
+ {
+ "name": "train_hh",
+ "records_path": "/from_s3/datasets/train_chat_filtered.jsonl",
+ "sample_rate": 1.0
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 1024,
+ "only_answer_loss": true,
+ "keep_end": true,
+ "random_cut": false,
+ "check_max_len_enforced": true
+ },
+ "val_dataset_settings": {
+ "sources": [
+ {
+ "name": "val_hh",
+ "records_path": "/from_s3/datasets/val_chat_filtered.jsonl",
+ "sample_rate": 0.2
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 1024,
+ "only_answer_loss": true,
+ "keep_end": true,
+ "random_cut": false,
+ "check_max_len_enforced": true
+ },
+
+ "cherry_pick_settings": {
+ "generator_transformers_settings": {
+ "num_beams": 1,
+ "do_sample": true,
+ "num_return_sequences": 3,
+ "max_new_tokens": 1024,
+ "temperature": 0.7,
+ "top_p": 1.0,
+ "top_k": 0,
+ "repetition_penalty": 1.0
+ },
+ "custom_generation_settings": {
+ "generation_eos_token": "<|eot_id|>",
+ "remove_prompt": true,
+ "skip_special_tokens": true
+ },
+ "dataset_settings": {
+ "sources": [
+ {
+ "name": "cp_hh",
+ "records_path": "/from_s3/datasets/val_chat_filtered.jsonl",
+ "num_samples": 10
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 1024,
+ "random_cut": false,
+ "keep_end": true,
+ "only_answer_loss": true,
+ "check_max_len_enforced": true
+ },
+ "metric_settings": [
+ {
+ "type": "length",
+ "parameters": {
+ "need_average": [true, false]
+ }
+ },
+ {
+ "type": "self_bleu",
+ "parameters": {
+ "need_average": [true, false]
+ }
+ },
+ {
+ "type": "dist_n",
+ "parameters": {
+ "need_average": [true, false]
+ }
+ },
+ {
+ "type": "diversity",
+ "parameters": {
+ "need_average": [true, false]
+ }
+ },
+ {
+ "type": "perplexity",
+ "parameters": {
+ "need_average": [true, false]
+ }
+ }
+ ]
+ },
+ "reward_model_settings": {
+ "model_path": "/from_s3/models/rm",
+ "model_type": "llama_reward_value",
+ "model_kwargs": {
+ "num_labels": 1,
+ "attn_implementation": "flash_attention_2"
+ },
+ "transformers_settings": {},
+ "check_checkpoint_is_compatible": true
+ },
+ "model_settings": {
+ "model_path": "/from_s3/models/policy/",
+ "model_type": "causal",
+ "transformers_settings": {},
+ "model_kwargs": {
+ "attn_implementation": "flash_attention_2"
+ },
+ "check_checkpoint_is_compatible": true
+ },
+ "tokenizer_settings": {
+ "use_fast": true,
+ "padding_side": "left"
+ },
+ "peft_settings": {
+ "inference_mode": false,
+ "r": 64,
+ "lora_alpha": 64,
+ "lora_dropout": 0.0,
+ "use_rslora": false,
+ "init_lora_weights": true,
+ "task_type": "CAUSAL_LM"
+ },
+ "trainer_settings": {
+ "evaluation_strategy": "steps",
+ "per_device_train_batch_size": 8,
+ "per_device_eval_batch_size": 8,
+ "gradient_accumulation_steps": 10,
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.95,
+ "adam_epsilon": 1e-5,
+ "eval_steps": 50,
+ "save_steps": 100,
+ "save_strategy": "steps",
+ "load_best_model_at_end": false,
+ "logging_steps": 1,
+ "learning_rate": 1e-5,
+ "max_steps": 1001,
+ "lr_scheduler_type": "constant_with_warmup",
+ "warmup_ratio": 0.03,
+ "fp16": false,
+ "bf16": true,
+ "optim": "adamw_torch",
+ "weight_decay": 0.0,
+ "max_grad_norm": 2,
+ "save_total_limit": 11,
+ "dataloader_num_workers": 12,
+ "deepspeed": "configs/exp/deepspeed/ds_config_stage_2_reinforce.json",
+
+ "max_response_length": 1024,
+ "stop_token": "<|eot_id|>",
+ "temperature": 0.7,
+ "non_eos_penalty": true,
+ "penalty_reward_value": -1,
+ "clip_rewards_min": -1e8,
+ "clip_rewards_max": 1e8,
+ "whiten_rewards": false,
+ "kl_coef": 0.05,
+ "risk_coef": 0.00,
+ "mean_baseline_coef": 0.95,
+
+ "init_kl_coef": null,
+ "target_kl": null,
+ "adaptive_kl_k": null,
+ "adaptive_kl_clip_value": null,
+ "min_kl_coef": null,
+ "max_kl_coef": null,
+
+ "num_samples_for_reward_stats": 1000
+
+ },
+ "wandb_settings": {
+ "project_name": "...",
+ "group": "hh-llama3-8b",
+ "job_type": "reinforce",
+ "run_name": "rvm-kl_0.05",
+ "entity": "...",
+ "tags": ["reinforce", "llama3-8b", "hh", "rvm"]
+ },
+ "log_path": "train_output/hh/llama3-8b/reinforce/rvm-kl_0.05-0.01/"
+}
+
diff --git a/tests/fixtures/configs/train/reinforce/base.json b/tests/fixtures/configs/train/reinforce/base.json
new file mode 100644
index 0000000..520ced7
--- /dev/null
+++ b/tests/fixtures/configs/train/reinforce/base.json
@@ -0,0 +1,175 @@
+{
+ "train_dataset_settings": {
+ "sources": [
+ {
+ "name": "train_chat",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "sample_rate": 1.0
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 2000,
+ "only_answer_loss": true
+ },
+ "val_dataset_settings": {
+ "sources": [
+ {
+ "name": "val_chat",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "sample_rate": 1.0
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 2000,
+ "only_answer_loss": true
+ },
+ "cherry_pick_settings": {
+ "generator_transformers_settings": {
+ "num_beams": 1,
+ "do_sample": false,
+ "stop_strings": "",
+ "max_new_tokens": 8
+ },
+ "custom_generation_settings": {
+ "skip_special_tokens": false
+ },
+ "dataset_settings": {
+ "sources": [
+ {
+ "name": "chat_test",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "num_samples": 2
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 150,
+ "only_answer_loss": true
+ },
+ "metric_settings": [
+ {
+ "type": "length",
+ "parameters": {
+ "need_average": [
+ true
+ ]
+ }
+ },
+ {
+ "type": "kl",
+ "parameters": {
+ "need_average": [
+ true
+ ],
+ "ref_logits_type": "sft"
+ }
+ },
+ {
+ "type": "kl",
+ "parameters": {
+ "need_average": [
+ true
+ ],
+ "ref_logits_type": "reference"
+ }
+ }
+ ]
+ },
+ "reward_model_settings": {
+ "model_path": "tests/fixtures/models/llama2_tiny",
+ "model_type": "seq_cls",
+ "model_kwargs": {
+ "num_labels": 1
+ },
+ "transformers_settings": {}
+ },
+ "model_settings": {
+ "model_path": "tests/fixtures/models/llama2_tiny",
+ "model_type": "causal",
+ "transformers_settings": {},
+ "embeddings_initialization_strategy": {
+ "<|start_header_id|>": "",
+ "<|end_header_id|>": "",
+ "<|eot_id|>": ""
+ }
+ },
+ "tokenizer_settings": {
+ "tokenizer_kwargs": {
+ "padding_side": "left"
+ }
+ },
+ "trainer_settings": {
+ "evaluation_strategy": "steps",
+ "per_device_train_batch_size": 8,
+ "per_device_eval_batch_size": 8,
+ "gradient_accumulation_steps": 10,
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.95,
+ "adam_epsilon": 0.00001,
+ "eval_steps": 50,
+ "save_steps": 100,
+ "save_strategy": "steps",
+ "load_best_model_at_end": false,
+ "logging_steps": 1,
+ "learning_rate": 0.00001,
+ "max_steps": 1001,
+ "lr_scheduler_type": "constant_with_warmup",
+ "warmup_ratio": 0.03,
+ "fp16": false,
+ "bf16": false,
+ "optim": "adamw_torch",
+ "weight_decay": 0.0,
+ "max_grad_norm": 2,
+ "save_total_limit": 11,
+ "dataloader_num_workers": 12,
+ "stop_token": "<|eot_id|>",
+ "temperature": 0.7,
+ "non_eos_penalty": true,
+ "penalty_reward_value": -1,
+ "clip_rewards_min": -1e+8,
+ "clip_rewards_max": 1e+8,
+ "whiten_rewards": false,
+ "kl_coef": 0.05,
+ "mean_baseline_coef": 0.95,
+ "num_samples_for_reward_stats": 1000,
+ "no_cuda": true
+ },
+ "special_tokens_settings": {
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": ""
+ },
+ "logging_settings": {
+ "project_name": "alignment",
+ "run_name": "sft",
+ "entity": "turbo-alignment"
+ },
+ "seed": 0,
+ "log_path": "train_output"
+}
diff --git a/tests/fixtures/configs/train/reinforce/gpt2_base.json b/tests/fixtures/configs/train/reinforce/gpt2_base.json
new file mode 100644
index 0000000..7560fb9
--- /dev/null
+++ b/tests/fixtures/configs/train/reinforce/gpt2_base.json
@@ -0,0 +1,175 @@
+{
+ "train_dataset_settings": {
+ "sources": [
+ {
+ "name": "train_chat",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "sample_rate": 1.0
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 2000,
+ "only_answer_loss": true
+ },
+ "val_dataset_settings": {
+ "sources": [
+ {
+ "name": "val_chat",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "sample_rate": 1.0
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 2000,
+ "only_answer_loss": true
+ },
+ "cherry_pick_settings": {
+ "generator_transformers_settings": {
+ "num_beams": 1,
+ "do_sample": false,
+ "stop_strings": "",
+ "max_new_tokens": 8
+ },
+ "custom_generation_settings": {
+ "skip_special_tokens": false
+ },
+ "dataset_settings": {
+ "sources": [
+ {
+ "name": "chat_test",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "num_samples": 2
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "assistant",
+ "user": "user",
+ "system": "system"
+ },
+ "prefix_template": "<|start_header_id|>{role}<|end_header_id|>\n\n",
+ "suffix_template": "<|eot_id|>"
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 150,
+ "only_answer_loss": true
+ },
+ "metric_settings": [
+ {
+ "type": "length",
+ "parameters": {
+ "need_average": [
+ true
+ ]
+ }
+ },
+ {
+ "type": "kl",
+ "parameters": {
+ "need_average": [
+ true
+ ],
+ "ref_logits_type": "sft"
+ }
+ },
+ {
+ "type": "kl",
+ "parameters": {
+ "need_average": [
+ true
+ ],
+ "ref_logits_type": "reference"
+ }
+ }
+ ]
+ },
+ "reward_model_settings": {
+ "model_path": "gpt2",
+ "model_type": "seq_cls",
+ "model_kwargs": {
+ "num_labels": 1
+ },
+ "transformers_settings": {}
+ },
+ "model_settings": {
+ "model_path": "gpt2",
+ "model_type": "causal",
+ "transformers_settings": {},
+ "embeddings_initialization_strategy": {
+ "<|start_header_id|>": "<|endoftext|>",
+ "<|end_header_id|>": "<|endoftext|>",
+ "<|eot_id|>": "<|endoftext|>"
+ }
+ },
+ "tokenizer_settings": {
+ "tokenizer_kwargs": {
+ "padding_side": "left"
+ }
+ },
+ "trainer_settings": {
+ "evaluation_strategy": "steps",
+ "per_device_train_batch_size": 1,
+ "per_device_eval_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.95,
+ "adam_epsilon": 0.00001,
+ "eval_steps": 50,
+ "save_steps": 100,
+ "save_strategy": "steps",
+ "load_best_model_at_end": false,
+ "logging_steps": 1,
+ "learning_rate": 0.00001,
+ "max_steps": 1001,
+ "lr_scheduler_type": "constant_with_warmup",
+ "warmup_ratio": 0.03,
+ "fp16": false,
+ "bf16": false,
+ "optim": "adamw_torch",
+ "weight_decay": 0.0,
+ "max_grad_norm": 2,
+ "save_total_limit": 11,
+ "dataloader_num_workers": 12,
+ "stop_token": "<|eot_id|>",
+ "temperature": 0.7,
+ "non_eos_penalty": true,
+ "penalty_reward_value": -1,
+ "clip_rewards_min": -1e+8,
+ "clip_rewards_max": 1e+8,
+ "whiten_rewards": false,
+ "kl_coef": 0.05,
+ "mean_baseline_coef": 0.95,
+ "num_samples_for_reward_stats": 10,
+ "no_cuda": true
+ },
+ "special_tokens_settings": {
+ "bos_token": "<|endoftext|>",
+ "eos_token": "<|endoftext|>",
+ "pad_token": "<|endoftext|>"
+ },
+ "logging_settings": {
+ "project_name": "alignment",
+ "run_name": "sft",
+ "entity": "turbo-alignment"
+ },
+ "seed": 0,
+ "log_path": "train_output"
+}
diff --git a/turbo_alignment/cherry_picks/multimodal.py b/turbo_alignment/cherry_picks/multimodal.py
index 80c6022..ba1c536 100755
--- a/turbo_alignment/cherry_picks/multimodal.py
+++ b/turbo_alignment/cherry_picks/multimodal.py
@@ -1,10 +1,10 @@
from typing import Iterable
import torch
+import wandb
from PIL import Image
from transformers import PreTrainedTokenizerBase
-import wandb
from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset
from turbo_alignment.generators.multimodal import MultimodalGenerator
diff --git a/turbo_alignment/cherry_picks/rm.py b/turbo_alignment/cherry_picks/rm.py
index 6cdcbcc..a86d7da 100755
--- a/turbo_alignment/cherry_picks/rm.py
+++ b/turbo_alignment/cherry_picks/rm.py
@@ -5,7 +5,7 @@
from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
from turbo_alignment.dataset.pair_preferences import PairPreferenceDataset
-from turbo_alignment.generators.rm import RMPairGenerator
+from turbo_alignment.generators import RMSamplingGenerator
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.settings.cherry_pick import RMCherryPickSettings
from turbo_alignment.settings.metric import ElementWiseScores, MetricResults
@@ -29,7 +29,8 @@ def _get_dataset_metrics(
) -> list[MetricResults]:
accelerator: Accelerator = kwargs.get('accelerator', None)
- generator = RMPairGenerator(
+ # FIXME: convert paired dataset to sampling dataset
+ generator = RMSamplingGenerator(
model=model,
tokenizer=tokenizer,
accelerator=accelerator,
diff --git a/turbo_alignment/cli/train.py b/turbo_alignment/cli/train.py
index da109ff..aa522ad 100755
--- a/turbo_alignment/cli/train.py
+++ b/turbo_alignment/cli/train.py
@@ -14,7 +14,7 @@ def train_sft_entrypoint(
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
- )
+ ),
) -> None:
experiment_settings = pipeline_settings.SftTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainSFTStrategy().run(experiment_settings)
@@ -27,7 +27,7 @@ def train_dpo_entrypoint(
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
- )
+ ),
) -> None:
experiment_settings = pipeline_settings.DPOTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainDPOStrategy().run(experiment_settings)
@@ -40,7 +40,7 @@ def train_kto_entrypoint(
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
- )
+ ),
) -> None:
experiment_settings = pipeline_settings.KTOTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainKTOStrategy().run(experiment_settings)
@@ -53,7 +53,7 @@ def train_ddpo_entrypoint(
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
- )
+ ),
) -> None:
experiment_settings = pipeline_settings.DDPOTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainDDPOStrategy().run(experiment_settings)
@@ -66,7 +66,7 @@ def train_rm_entrypoint(
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
- )
+ ),
) -> None:
experiment_settings = pipeline_settings.RMTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainRMStrategy().run(experiment_settings)
@@ -79,7 +79,7 @@ def classification_training(
'--experiment_settings_path',
exists=True,
help='Path to experiment config file',
- )
+ ),
) -> None:
experiment_settings = pipeline_settings.ClassificationTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainClassificationStrategy().run(experiment_settings)
@@ -88,8 +88,11 @@ def classification_training(
@app.command(name='train_multimodal', help='Train Multimodal')
def multimodal_training(
experiment_settings_path: Path = typer.Option(
- ..., '--experiment_settings_path', exists=True, help='Path to experiment config file'
- )
+ ...,
+ '--experiment_settings_path',
+ exists=True,
+ help='Path to experiment config file',
+ ),
) -> None:
experiment_settings = pipeline_settings.MultimodalTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainMultimodalStrategy().run(experiment_settings)
@@ -98,8 +101,24 @@ def multimodal_training(
@app.command(name='train_rag', help='Train RAG')
def rag_training(
experiment_settings_path: Path = typer.Option(
- ..., '--experiment_settings_path', exists=True, help='Path to experiment config file'
- )
+ ...,
+ '--experiment_settings_path',
+ exists=True,
+ help='Path to experiment config file',
+ ),
) -> None:
experiment_settings = pipeline_settings.RAGTrainExperimentSettings.parse_file(experiment_settings_path)
pipelines.TrainRAGStrategy().run(experiment_settings)
+
+
+@app.command(name='train_reinforce', help='Train REINFORCE pipeline')
+def reinforce_training(
+ experiment_settings_path: Path = typer.Option(
+ ...,
+ '--experiment_settings_path',
+ exists=True,
+ help='Path to experiment config file',
+ ),
+) -> None:
+ experiment_settings = pipeline_settings.REINFORCETrainExperimentSettings.parse_file(experiment_settings_path)
+ pipelines.TrainREINFORCEStrategy().run(experiment_settings)
diff --git a/turbo_alignment/common/data/multimodal/audio/imagebind.py b/turbo_alignment/common/data/multimodal/audio/imagebind.py
index c27503b..e9da72b 100644
--- a/turbo_alignment/common/data/multimodal/audio/imagebind.py
+++ b/turbo_alignment/common/data/multimodal/audio/imagebind.py
@@ -1,7 +1,12 @@
from pathlib import Path
import torch
-import torchaudio
+
+# FIXME: torchaudio is an optional dependency; keep the import soft so the rest of the package works without it
+try:
+    import torchaudio
+except ImportError:
+    torchaudio = None
from loguru import logger
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from torchvision import transforms
diff --git a/turbo_alignment/common/distributed.py b/turbo_alignment/common/distributed.py
new file mode 100644
index 0000000..a81c3cf
--- /dev/null
+++ b/turbo_alignment/common/distributed.py
@@ -0,0 +1,128 @@
+from datetime import timedelta
+from typing import Any, Literal
+
+import numpy as np
+import torch
+import os
+import torch.distributed
+from torch.distributed.distributed_c10d import (
+ Backend,
+ PrefixStore,
+ Store,
+ _new_process_group_helper,
+ _world,
+ default_pg_timeout,
+ rendezvous,
+)
+
+
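+# Number of distributed processes; the launcher sets WORLD_SIZE, default to 1 for single-process runs.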
+world_size = int(os.getenv("WORLD_SIZE", "1"))
+
+
+def get_global_mean(values: torch.Tensor) -> float:
+    # Sum of the values on the current process
+    local_sum = values.sum().item()
+
+    if world_size == 1:
+        return values.mean().item()
+
+    num_rewards = torch.tensor(len(values), device=values.device)
+
+    # Tensor holding the local sum, reduced below into the global sum
+    global_sum = torch.tensor(local_sum, device=values.device)
+
+    # Aggregate sums and element counts from all processes
+    torch.distributed.all_reduce(global_sum, op=torch.distributed.ReduceOp.SUM)
+    torch.distributed.all_reduce(num_rewards, op=torch.distributed.ReduceOp.SUM)
+    global_mean = (global_sum / num_rewards).item()
+
+    return global_mean
+
+
+def get_global_std(values: torch.Tensor, mean: float) -> float:
+    # Sum of squared deviations from the (global) mean on the current process
+    local_sum = ((values - mean) ** 2).sum().item()
+
+    if world_size == 1:
+        return float(np.sqrt(local_sum / (len(values) - 1)))
+
+    num_rewards = torch.tensor(len(values), device=values.device)
+
+    # Tensor holding the local sum of squared deviations, reduced below into the global sum
+    global_sum = torch.tensor(local_sum, device=values.device)
+
+    # Aggregate squared deviations and element counts from all processes
+    torch.distributed.all_reduce(global_sum, op=torch.distributed.ReduceOp.SUM)
+    torch.distributed.all_reduce(num_rewards, op=torch.distributed.ReduceOp.SUM)
+    global_std = np.sqrt((global_sum / (num_rewards - 1)).item())
+
+    return global_std
+
+
+def get_log_mean_std(
+ tensor: torch.Tensor, name: str, train_eval: Literal['train', 'eval'] | None = None
+) -> dict[str, float]:
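+    # Compute the cross-process mean/std of `tensor` and return them as a metrics dict,
+    # optionally prefixed with 'train/' or 'eval/'.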
+ mean = get_global_mean(tensor)
+ metrics = {}
+ if train_eval is not None:
+ metrics[f'{train_eval}/{name}_mean'] = mean
+ metrics[f'{train_eval}/{name}_std'] = get_global_std(tensor, mean=mean)
+ else:
+ metrics[f'{name}_mean'] = mean
+ metrics[f'{name}_std'] = get_global_std(tensor, mean=mean)
+
+ return metrics
+
+
+# Copy from pytorch to allow creating multiple main groups.
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+def init_process_group(
+ backend: str | Backend | None = None,
+ init_method: str | None = None,
+ timeout: timedelta | None = None,
+ world_size: int = -1,
+ rank: int = -1,
+ store: Store | None = None,
+ group_name: str | None = None,
+ pg_options: Any = None,
+):
+ assert (store is None) or (init_method is None), 'Cannot specify both init_method and store.'
+
+ if store is not None:
+ assert world_size > 0, 'world_size must be positive if using store'
+ assert rank >= 0, 'rank must be non-negative if using store'
+ elif init_method is None:
+ init_method = 'env://'
+
+ if backend:
+ backend = Backend(backend)
+ else:
+ backend = Backend('undefined')
+
+ if timeout is None:
+ timeout = default_pg_timeout
+
+ # backward compatible API
+ if store is None:
+ rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+
+ # Use a PrefixStore to avoid accidental overrides of keys used by
+ # different systems (e.g. RPC) in case the store is multi-tenant.
+ store = PrefixStore(group_name, store)
+
+ pg, _ = _new_process_group_helper(
+ world_size,
+ rank,
+ [],
+ backend,
+ store,
+ group_name=group_name,
+ pg_options=pg_options,
+ timeout=timeout,
+ )
+
+ _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+ return pg
diff --git a/turbo_alignment/common/logging/weights_and_biases.py b/turbo_alignment/common/logging/weights_and_biases.py
index 112e585..3b5bdfd 100755
--- a/turbo_alignment/common/logging/weights_and_biases.py
+++ b/turbo_alignment/common/logging/weights_and_biases.py
@@ -1,9 +1,9 @@
from typing import Any
+import wandb
from wandb.sdk.lib.disabled import RunDisabled
from wandb.sdk.wandb_run import Run
-import wandb
from turbo_alignment.settings.logging.weights_and_biases import WandbSettings
diff --git a/turbo_alignment/common/tf/callbacks/logging.py b/turbo_alignment/common/tf/callbacks/logging.py
index 1ad484d..646e280 100755
--- a/turbo_alignment/common/tf/callbacks/logging.py
+++ b/turbo_alignment/common/tf/callbacks/logging.py
@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
+import wandb
from clearml import Task
from transformers import (
TrainerCallback,
@@ -13,7 +14,6 @@
from wandb.sdk.lib.disabled import RunDisabled
from wandb.sdk.wandb_run import Run
-import wandb
from turbo_alignment.common.logging import get_project_logger
logger = get_project_logger()
diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 5415daf..86ddcfd 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -4,7 +4,12 @@
import torch
from accelerate import Accelerator
-from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
+from transformers import (
+ BatchEncoding,
+ GenerationConfig,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+)
from turbo_alignment.dataset.base import BaseDataset
from turbo_alignment.dataset.base.models import DatasetRecord
@@ -34,11 +39,11 @@ def __init__(
self._batch = batch
@abstractmethod
- def _generate_from_batch(
+ def generate_from_batch(
self,
- records: list[dict[str, Any]],
- original_records: list[DatasetRecordT],
dataset_name: str,
+ records: list[dict[str, Any]],
+ original_records: list[DatasetRecordT] | None = None,
) -> list[InferenceOutputT]:
...
@@ -64,10 +69,10 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]:
zip(records_batches, original_records_batches)
):
generations.append(
- self._generate_from_batch(
+ self.generate_from_batch(
+ dataset.source.name,
records_batch,
original_records_batch,
- dataset.source.name,
)
)
else:
@@ -75,10 +80,10 @@ def generate_from_dataset(self, dataset: BaseDataset) -> list[InferenceOutputT]:
list(zip(records_batches, original_records_batches)), apply_padding=True
) as accelerator_records:
generations = [
- self._generate_from_batch(
+ self.generate_from_batch(
+ dataset.source.name,
records_batch,
original_records_batch,
- dataset.source.name,
)
for records_batch, original_records_batch in accelerator_records
][: len(records_batches)]
@@ -92,13 +97,10 @@ def __init__(
transformers_settings: GeneratorTransformersSettings,
custom_generation_settings: CustomChatGenerationSettings,
tokenizer: PreTrainedTokenizerBase,
- return_logits: bool = False,
**kwargs,
) -> None:
super().__init__(tokenizer=tokenizer, **kwargs)
- self._return_logits = return_logits
-
self._transformers_generator_parameters = GenerationConfig(
bos_token_id=self._tokenizer.bos_token_id,
**transformers_settings.dict(),
@@ -107,25 +109,28 @@ def __init__(
self._custom_generation_settings = custom_generation_settings
@abstractmethod
- def _generate_from_single_record(
+ def generate_from_single_record(
self,
- record: dict[str, torch.Tensor],
- original_record: DatasetRecordT,
dataset_name: str,
+ record: dict[str, torch.Tensor],
+ original_record: DatasetRecordT | None = None,
) -> InferenceOutputT:
...
@abstractmethod
- def _generate_from_batch_records(
+ def generate_from_batch_records(
self,
- records: list[dict[str, torch.Tensor]],
- original_records: list[DatasetRecordT],
dataset_name: str,
+ records_batch: dict[str, torch.Tensor] | BatchEncoding,
+ original_records: list[DatasetRecordT] | None = None,
) -> list[InferenceOutputT]:
...
- def _generate_from_batch(
- self, records: list[dict[str, Any]], original_records: list[DatasetRecordT], dataset_name: str
+ def generate_from_batch(
+ self,
+ dataset_name: str,
+ records: list[dict[str, Any]],
+ original_records: list[DatasetRecordT] | None = None,
) -> list[InferenceOutputT]:
if self._custom_generation_settings.batch > 1:
if self._transformers_generator_parameters.num_beams != 1:
@@ -134,18 +139,52 @@ def _generate_from_batch(
self._tokenizer.padding_side = 'left'
self._tokenizer.pad_token_id = self._tokenizer.pad_token_id
- return self._generate_from_batch_records(records, original_records, dataset_name)
+ input_ids = [record['input_ids'].tolist() for record in records]
+ attention_mask = [record['attention_mask'].tolist() for record in records]
+
+ max_input_length = max(len(sample) for sample in input_ids)
+
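+        # Pad all records in the batch to the longest prompt so they can be generated together.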
+ records_batch = self._tokenizer.pad(
+ {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ },
+ padding='max_length',
+ max_length=max_input_length,
+ return_tensors='pt',
+ )
+
+ return self.generate_from_batch_records(dataset_name, records_batch, original_records)
return [
- self._generate_from_single_record(record, original_record, dataset_name)
+ self.generate_from_single_record(dataset_name, record, original_record)
for record, original_record in zip(records, original_records)
]
- @staticmethod
- def _postprocess(input_indices: torch.Tensor, output_indices: torch.Tensor, remove_prompt: bool) -> torch.Tensor:
+ def _postprocess(
+ self,
+ input_indices: torch.Tensor,
+ output_indices: torch.Tensor,
+ logits: torch.Tensor | None,
+ remove_prompt: bool,
+ only_answer_logits: bool,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, list[str]]:
+ processed_logits: torch.Tensor | None = None
+ processed_output_indices = output_indices.cpu()
+
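+        # `only_answer_logits` drops logits for prompt positions; `remove_prompt` does the same for token ids.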
+ if logits is not None:
+ processed_logits = logits.cpu()
+ if only_answer_logits:
+ processed_logits = logits[:, input_indices.shape[1] :, :].cpu()
+
if remove_prompt:
- return output_indices[:, input_indices.shape[1] :].cpu()
- return output_indices.cpu()
+ processed_output_indices = output_indices[:, input_indices.shape[1] :]
+
+ answers = self._decode(token_indices=processed_output_indices)
+
+ answers_attention_mask = processed_output_indices != self._tokenizer.pad_token_id
+
+ return processed_output_indices, answers_attention_mask, processed_logits, answers
def _decode(self, token_indices: torch.Tensor) -> list[str]:
return self._tokenizer.batch_decode(
diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py
index acc4337..579491c 100755
--- a/turbo_alignment/generators/chat.py
+++ b/turbo_alignment/generators/chat.py
@@ -1,6 +1,7 @@
from typing import Any
import torch
+from transformers import BatchEncoding
from turbo_alignment.dataset.chat.models import ChatDatasetRecord
from turbo_alignment.generators.base import ChatGeneratorBase
@@ -11,30 +12,16 @@
class ChatGenerator(ChatGeneratorBase[ChatDatasetRecord, ChatInferenceOutput]):
- def _generate_from_batch_records(
+ def generate_from_batch_records(
self,
- records: list[dict[str, torch.Tensor]],
- original_records: list[ChatDatasetRecord],
dataset_name: str,
+ records_batch: dict[str, torch.Tensor] | BatchEncoding,
+ original_records: list[ChatDatasetRecord] | None = None,
) -> list[ChatInferenceOutput]:
- input_ids = [record['input_ids'].tolist() for record in records]
- attention_mask = [record['attention_mask'].tolist() for record in records]
-
- max_input_length = max(len(sample) for sample in input_ids)
-
- batch = self._tokenizer.pad(
- {
- 'input_ids': input_ids,
- 'attention_mask': attention_mask,
- },
- padding='max_length',
- max_length=max_input_length,
- return_tensors='pt',
- )
-
- batched_input_ids = batch['input_ids'].to(self.device)
- batched_attention_mask = batch['attention_mask'].to(self.device)
+ batched_input_ids = records_batch['input_ids'].to(self.device)
+ batched_attention_mask = records_batch['attention_mask'].to(self.device)
+ assert self._tokenizer.padding_side == 'left'
output_indices = self._model.generate(
inputs=batched_input_ids,
attention_mask=batched_attention_mask,
@@ -43,39 +30,60 @@ def _generate_from_batch_records(
pad_token_id=self._tokenizer.pad_token_id,
)
- postprocessed_output_indices = self._postprocess(
+ logits: torch.Tensor | None = None
+
+ if self._custom_generation_settings.return_logits:
+ with torch.no_grad():
+ logits = self._model(output_indices).logits
+
+ postprocessed_output_indices, answers_attention_mask, postprocessed_logits, answers = self._postprocess(
input_indices=batched_input_ids,
output_indices=output_indices,
+ logits=logits,
remove_prompt=self._custom_generation_settings.remove_prompt,
+ only_answer_logits=self._custom_generation_settings.only_answer_logits,
)
- answers = self._decode(token_indices=postprocessed_output_indices)
-
- return [
- ChatInferenceOutput(
- id=original_record.id,
- dataset_name=dataset_name,
- messages=original_record.messages,
- label=original_record.label,
- meta=original_record.meta,
- answers=[
- AnswerMessage(
- id='0',
- content=answer,
- input_token_ids=None,
- answer_token_ids=None,
- logits=None,
+ outputs = []
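+        # generate() returns `num_return_sequences` answers per prompt; regroup the flat outputs by input record.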
+ for i, (input_ids, attention_mask) in enumerate(zip(batched_input_ids, batched_attention_mask)):
+ ans_batch_start = i * self._transformers_generator_parameters.num_return_sequences
+ ans_batch_end = (i + 1) * self._transformers_generator_parameters.num_return_sequences
+ batch_answer_tokens = postprocessed_output_indices[ans_batch_start:ans_batch_end, :]
+ batch_answer_attention_masks = answers_attention_mask[ans_batch_start:ans_batch_end, :]
+ batch_answers = answers[ans_batch_start:ans_batch_end]
+
+ for j, (answer, answer_tokens, answer_attention_mask) in enumerate(
+ zip(batch_answers, batch_answer_tokens, batch_answer_attention_masks)
+ ):
+                answer_logits = None if postprocessed_logits is None else postprocessed_logits[ans_batch_start + j, :, :]
+
+ outputs.append(
+ ChatInferenceOutput(
+ dataset_name=dataset_name,
+ messages=original_records[i].messages if original_records else None,
+ label=original_records[i].label if original_records else None,
+ meta=original_records[i].meta if original_records else None,
+ input_token_ids=input_ids,
+ input_attention_mask=attention_mask,
+ answers=[
+ AnswerMessage(
+ id=str(i),
+ content=answer,
+ answer_token_ids=answer_tokens,
+ logits=answer_logits,
+ answer_attention_mask=answer_attention_mask,
+ )
+ ],
)
- ],
- )
- for original_record, answer in zip(original_records, answers)
- ]
+ )
+
+ return outputs
- def _generate_from_single_record(
+ def generate_from_single_record(
self,
- record: dict[str, Any],
- original_record: ChatDatasetRecord,
dataset_name: str,
+ record: dict[str, Any],
+ original_record: ChatDatasetRecord | None = None,
) -> ChatInferenceOutput:
input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device)
attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self.device)
@@ -88,52 +96,41 @@ def _generate_from_single_record(
pad_token_id=self._tokenizer.pad_token_id,
)
- postprocessed_output_indices = self._postprocess(
- input_indices=input_ids,
- output_indices=output_indices,
- remove_prompt=self._custom_generation_settings.remove_prompt,
- )
-
- answers = self._decode(token_indices=postprocessed_output_indices)
-
logits: torch.Tensor | None = None
- input_token_ids: torch.Tensor | None = None
- answer_tokens_ids: torch.Tensor | None = None
- if self._return_logits:
+ if self._custom_generation_settings.return_logits:
with torch.no_grad():
logits = self._model(output_indices).logits.cpu()
- answer_tokens_ids = postprocessed_output_indices
- input_token_ids = input_ids
+ postprocessed_output_indices, answers_attention_mask, postprocessed_logits, answers = self._postprocess(
+ input_indices=input_ids,
+ output_indices=output_indices,
+ logits=logits,
+ remove_prompt=self._custom_generation_settings.remove_prompt,
+ only_answer_logits=self._custom_generation_settings.only_answer_logits,
+ )
- answer_messages = [
- AnswerMessage(
- id=str(i),
- content=a,
- input_token_ids=input_token_ids,
- answer_token_ids=a_t_ids.unsqueeze(0),
- logits=l.unsqueeze(0),
- )
- for i, (a, a_t_ids, l) in enumerate(zip(answers, answer_tokens_ids, logits)) # type: ignore[arg-type]
- ]
- else:
- answer_messages = [
+ answer_messages = []
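+        # Build one AnswerMessage per generated sequence (num_return_sequences may be > 1).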
+ for i, (answer, answer_tokens, answer_attention_mask) in enumerate(
+ zip(answers, postprocessed_output_indices, answers_attention_mask)
+ ):
+ answer_logits = None if postprocessed_logits is None else postprocessed_logits[i, :, :].unsqueeze(0)
+ answer_messages.append(
AnswerMessage(
id=str(i),
- content=a,
- input_token_ids=input_token_ids,
- answer_token_ids=answer_tokens_ids,
- logits=logits,
+ content=answer,
+ answer_token_ids=answer_tokens.unsqueeze(0),
+ answer_attention_mask=answer_attention_mask,
+ logits=answer_logits,
)
- for i, a in enumerate(answers)
- ]
+ )
return ChatInferenceOutput(
- id=original_record.id,
dataset_name=dataset_name,
messages=original_record.messages,
label=original_record.label,
meta=original_record.meta,
answers=answer_messages,
+ input_token_ids=input_ids,
+ input_attention_mask=attention_mask,
)
diff --git a/turbo_alignment/generators/classification.py b/turbo_alignment/generators/classification.py
index e809e16..1bca7a9 100755
--- a/turbo_alignment/generators/classification.py
+++ b/turbo_alignment/generators/classification.py
@@ -16,8 +16,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
self._collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)
- def _generate_from_batch(
- self, records: list[dict[str, Any]], original_records: list[ClassificationDatasetRecord], dataset_name: str
+ def generate_from_batch(
+ self,
+ dataset_name: str,
+ records: list[dict[str, Any]],
+ original_records: list[ClassificationDatasetRecord] | None = None,
) -> list[ClassificationInferenceOutput]:
inputs = [{'input_ids': record['input_ids']} for record in records]
diff --git a/turbo_alignment/generators/multimodal.py b/turbo_alignment/generators/multimodal.py
index 0c64137..731b3f9 100755
--- a/turbo_alignment/generators/multimodal.py
+++ b/turbo_alignment/generators/multimodal.py
@@ -1,5 +1,5 @@
import torch
-from transformers import PreTrainedTokenizerBase
+from transformers import BatchEncoding, PreTrainedTokenizerBase
from turbo_alignment.dataset.multimodal.models import MultimodalDatasetRecord
from turbo_alignment.generators.base import ChatGeneratorBase
@@ -26,19 +26,19 @@ def __init__(
**kwargs,
)
- def _generate_from_batch_records(
+ def generate_from_batch_records(
self,
- records: list[dict[str, torch.Tensor]],
- original_records: list[MultimodalDatasetRecord],
dataset_name: str,
+ records_batch: dict[str, torch.Tensor] | BatchEncoding,
+ original_records: list[MultimodalDatasetRecord] | None = None,
) -> list[MultimodalInferenceOutput]:
raise ValueError('You can not use batch generation with multimodаl generator')
- def _generate_from_single_record(
+ def generate_from_single_record(
self,
- record: dict[str, torch.Tensor],
- original_record: MultimodalDatasetRecord,
dataset_name: str,
+ record: dict[str, torch.Tensor],
+ original_record: MultimodalDatasetRecord | None = None,
) -> MultimodalInferenceOutput:
input_ids = torch.unsqueeze(record['input_ids'], 0).to(self._model.language_model.device)
attention_mask = torch.unsqueeze(record['attention_mask'], 0).to(self._model.language_model.device)
@@ -60,7 +60,7 @@ def _generate_from_single_record(
input_token_ids: torch.Tensor | None = None
answer_tokens_ids: torch.Tensor | None = None
- if self._return_logits:
+ if self._custom_generation_settings.return_logits:
with torch.no_grad():
logits = self._model(output_indices).logits
@@ -72,11 +72,11 @@ def _generate_from_single_record(
dataset_name=dataset_name,
messages=original_record.messages,
label=original_record.label,
+ input_token_ids=input_token_ids,
answers=[
AnswerMessage(
id=str(i),
content=a,
- input_token_ids=input_token_ids,
answer_token_ids=answer_tokens_ids,
logits=logits,
)
diff --git a/turbo_alignment/generators/rag.py b/turbo_alignment/generators/rag.py
index 4203de8..25dd42e 100755
--- a/turbo_alignment/generators/rag.py
+++ b/turbo_alignment/generators/rag.py
@@ -1,28 +1,35 @@
-from typing import Any
-
import torch
+from transformers import BatchEncoding
from turbo_alignment.dataset.chat.models import ChatDatasetRecord
-from turbo_alignment.generators.chat import ChatGenerator
+from turbo_alignment.generators.chat import ChatGeneratorBase
from turbo_alignment.settings.generators.outputs.chat import (
AnswerMessage,
RagInferenceOutput,
)
-class RagGenerator(ChatGenerator):
- def _generate_from_single_record(
+class RagGenerator(ChatGeneratorBase[ChatDatasetRecord, RagInferenceOutput]):
+ def generate_from_batch_records(
+ self,
+ dataset_name: str,
+ records_batch: dict[str, torch.Tensor] | BatchEncoding,
+ original_records: list[ChatDatasetRecord] | None = None,
+ ) -> list[RagInferenceOutput]:
+        raise ValueError('You cannot use batch generation with the RAG generator')
+
+ def generate_from_single_record(
self,
- record: dict[str, Any],
- original_record: ChatDatasetRecord,
dataset_name: str,
+ record: dict[str, torch.Tensor],
+ original_record: ChatDatasetRecord | None = None,
) -> RagInferenceOutput:
input_ids = torch.unsqueeze(record['input_ids'], 0).to(self.device)
answer_indices, document_indices, doc_scores = self._model.generate(
inputs=input_ids,
generation_config=self._transformers_generator_parameters,
- tokenizer=self._tokenizer.current_tokenizer,
+ tokenizer=self._tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
)
diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py
index ce9a9c8..b9d3892 100755
--- a/turbo_alignment/generators/rm.py
+++ b/turbo_alignment/generators/rm.py
@@ -1,95 +1,69 @@
from typing import Any
import torch
-from torch import nn
-from transformers import DataCollatorWithPadding, PreTrainedTokenizerBase
+from transformers import BatchEncoding, DataCollatorWithPadding, PreTrainedTokenizerBase
-from turbo_alignment.dataset.pair_preferences import PairPreferenceRecord
from turbo_alignment.dataset.sampling.models import SamplingDatasetRecord
from turbo_alignment.generators.base import BaseGenerator
-from turbo_alignment.settings.generators.outputs.rm import (
- RMPairInferenceOutput,
- RMSamplingInferenceOutput,
-)
-
-
-class RMPairGenerator(BaseGenerator[PairPreferenceRecord, RMPairInferenceOutput]):
- def __init__(self, tokenizer: PreTrainedTokenizerBase, **kwargs):
- self._collator = DataCollatorWithPadding(tokenizer=tokenizer)
-
- super().__init__(tokenizer=tokenizer, **kwargs)
-
- def _generate_from_batch(
- self, records: list[dict[str, Any]], original_records: list[PairPreferenceRecord], dataset_name: str
- ) -> list[RMPairInferenceOutput]:
- merged_inputs = [r['inputs_w'] for r in records] + [r['inputs_l'] for r in records]
- batch = self._collator(merged_inputs)
- input_ids = batch['input_ids'].to(self.device)
- attn_mask = batch['attention_mask'].to(self.device)
-
- with torch.no_grad():
- rewards = self._model(input_ids=input_ids, attention_mask=attn_mask).logits.cpu()
- rewards_w, rewards_l = rewards[: len(records)], rewards[len(records) :]
-
- return [
- RMPairInferenceOutput(
- id=record.id,
- context=record.context,
- answer_w=record.answer_w,
- answer_l=record.answer_l,
- reward_w=reward_w.item(),
- reward_l=reward_l.item(),
- dataset_name=dataset_name,
- )
- for record, reward_w, reward_l in zip(original_records, rewards_w, rewards_l)
- ]
+from turbo_alignment.settings.generators.outputs.rm import RMSamplingInferenceOutput
class RMSamplingGenerator(BaseGenerator[SamplingDatasetRecord, RMSamplingInferenceOutput]):
- def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int, **kwargs):
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, micro_batch: int = 1, **kwargs):
self._collator = DataCollatorWithPadding(tokenizer=tokenizer)
self._micro_batch = micro_batch
super().__init__(tokenizer=tokenizer, **kwargs)
- def _generate_from_batch(
- self, records: list[dict[str, Any]], original_records: list[SamplingDatasetRecord], dataset_name: str
+ def generate_from_batch_records(self, records: dict[str, torch.Tensor] | BatchEncoding) -> torch.Tensor:
+ with torch.no_grad():
+ rewards = self._model(**records).logits.cpu()
+        return rewards
+
+ def generate_from_batch(
+ self,
+ dataset_name: str,
+ records: list[dict[str, Any]],
+ original_records: list[SamplingDatasetRecord] | None = None,
) -> list[RMSamplingInferenceOutput]:
+ self._tokenizer.padding_side = 'left'
+ self._tokenizer.pad_token_id = self._tokenizer.pad_token_id
+
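+        # Flatten every answer of every record into a single list and score it in micro-batches.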
merged_inputs = [inputs for record in records for key, inputs in record['answers'].items()]
- if len(merged_inputs) == 0:
- return []
+ input_ids = [record['input_ids'].tolist() for record in merged_inputs]
+ attention_mask = [record['attention_mask'].tolist() for record in merged_inputs]
rewards = []
- with torch.no_grad():
- input_ids = nn.utils.rnn.pad_sequence(
- [item['input_ids'] for item in merged_inputs],
- padding_value=self._tokenizer.pad_token_id,
- batch_first=True,
+ for i in range(0, len(input_ids), self._micro_batch):
+ input_ids_batch = input_ids[i : i + self._micro_batch]
+ attn_mask_batch = attention_mask[i : i + self._micro_batch]
+
+ max_input_length = max(len(sample) for sample in input_ids_batch)
+
+ records_batch = self._tokenizer.pad(
+ {
+ 'input_ids': input_ids_batch,
+ 'attention_mask': attn_mask_batch,
+ },
+ padding='max_length',
+ max_length=max_input_length,
+ return_tensors='pt',
+ ).to(self.device)
+
+ rewards.extend(self.generate_from_batch_records(records_batch).tolist())
+
+        outputs = []
+        reward_idx = 0
+        for record in original_records:
+            record_rewards = {}
+            for answer in record.answers:
+                record_rewards[answer.id] = rewards[reward_idx]
+                reward_idx += 1
+
+ outputs.append(
+ RMSamplingInferenceOutput(
+ id=record.id,
+ answers=record.answers,
+ dataset_name=record.dataset_name,
+ messages=record.messages,
+ rewards=record_rewards,
+ )
)
- attn_mask = nn.utils.rnn.pad_sequence(
- [item['attention_mask'] for item in merged_inputs],
- padding_value=0,
- batch_first=True,
- )
- for i in range(0, len(input_ids), self._micro_batch):
- input_ids_batch = input_ids[i : i + self._micro_batch].to(self.device)
- attn_mask_batch = attn_mask[i : i + self._micro_batch].to(self.device)
- rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits.cpu())
-
- rewards = torch.cat(rewards, dim=0)
- record_rewards = [
- {answer_id: reward.item() for answer_id, reward in zip(record['answers'].keys(), rewards)}
- for record in records
- ]
-
- return [
- RMSamplingInferenceOutput(
- id=record.id,
- rewards=rewards,
- messages=record.messages,
- dataset_name=dataset_name,
- answers=record.answers,
- )
- for record, rewards in zip(original_records, record_rewards)
- ]
+ return outputs
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index ff52de3..03dfd22 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -2,7 +2,11 @@
import torch
from transformers import PreTrainedTokenizerBase
-from vllm import LLM, SamplingParams
+
+try:
+ from vllm import LLM, SamplingParams
+except ImportError as exc:
+    raise ImportError('vLLM is required to use turbo_alignment.generators.vllm_chat') from exc
from turbo_alignment.dataset.chat import ChatDatasetRecord
from turbo_alignment.generators.base import BaseGenerator
@@ -22,7 +26,6 @@ def __init__(
model: LLM,
tokenizer: PreTrainedTokenizerBase,
batch: int,
- return_logits: bool = False,
):
model.set_tokenizer(tokenizer)
super().__init__(model, tokenizer, batch=batch)
@@ -52,10 +55,13 @@ def __init__(
**beam_search_params,
)
- self._return_logits = return_logits
+ self._custom_generation_settings = custom_generation_settings
- def _generate_from_batch(
- self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
+ def generate_from_batch(
+ self,
+ dataset_name: str,
+ records: list[dict[str, Any]],
+ original_records: list[ChatDatasetRecord] | None = None,
) -> list[ChatInferenceOutput]:
input_ids = [record['input_ids'].tolist() for record in records]
request_outputs = self._model.generate(
@@ -74,7 +80,7 @@ def _generate_from_batch(
content=a.text,
sequence_score=a.cumulative_logprob,
)
- if self._return_logits:
+ if self._custom_generation_settings.return_logits:
ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
ans_msg.answer_token_ids = torch.tensor(a.token_ids).unsqueeze(0)
diff --git a/turbo_alignment/pipelines/inference/chat.py b/turbo_alignment/pipelines/inference/chat.py
index a40df0f..0a3a051 100755
--- a/turbo_alignment/pipelines/inference/chat.py
+++ b/turbo_alignment/pipelines/inference/chat.py
@@ -26,7 +26,10 @@ def _get_single_inference_settings(
)
if model_inference_settings.use_vllm:
- import vllm
+ try:
+ import vllm
+ except ImportError as exc:
+                raise ImportError('vLLM is required when use_vllm is enabled') from exc
from turbo_alignment.generators.vllm_chat import VLLMChatGenerator
diff --git a/turbo_alignment/pipelines/train/__init__.py b/turbo_alignment/pipelines/train/__init__.py
index ac24318..9511009 100755
--- a/turbo_alignment/pipelines/train/__init__.py
+++ b/turbo_alignment/pipelines/train/__init__.py
@@ -4,5 +4,6 @@
from .kto import TrainKTOStrategy
from .multimodal import TrainMultimodalStrategy
from .rag import TrainRAGStrategy
+from .reinforce import TrainREINFORCEStrategy
from .rm import TrainRMStrategy
from .sft import TrainSFTStrategy
diff --git a/turbo_alignment/pipelines/train/reinforce.py b/turbo_alignment/pipelines/train/reinforce.py
new file mode 100644
index 0000000..ef04fe3
--- /dev/null
+++ b/turbo_alignment/pipelines/train/reinforce.py
@@ -0,0 +1,132 @@
+from typing import Callable
+
+from torch.utils.data import ConcatDataset, Dataset
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers.data.data_collator import (
+ DataCollatorForTokenClassification,
+ DataCollatorMixin,
+)
+
+from turbo_alignment.cherry_picks.chat import ChatCherryPickCallback
+from turbo_alignment.common.logging import get_project_logger
+from turbo_alignment.common.tf.loaders import load_model
+from turbo_alignment.constants import TRAINER_LOGS_FOLDER
+from turbo_alignment.dataset.chat.chat import InferenceChatDataset
+from turbo_alignment.dataset.loader import DatasetLoader
+from turbo_alignment.metrics.metric import Metric
+from turbo_alignment.metrics.registry import MetricSettingsRegistry
+from turbo_alignment.pipelines.train.base import BaseTrainStrategy
+from turbo_alignment.settings.datasets.base import DatasetStrategy
+from turbo_alignment.settings.pipelines.train.reinforce import (
+ REINFORCETrainExperimentSettings,
+)
+from turbo_alignment.trainers.online.reinforce import (
+ REINFORCETrainer,
+ REINFORCETrainingArguments,
+)
+
+logger = get_project_logger()
+
+
+class TrainREINFORCEStrategy(BaseTrainStrategy[REINFORCETrainExperimentSettings]):
+ @staticmethod
+ def _get_data_collator(
+ experiment_settings: REINFORCETrainExperimentSettings,
+ tokenizer: PreTrainedTokenizerBase,
+ **kwargs,
+ ) -> Callable:
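+        # DataCollatorForTokenClassification pads input_ids, attention_mask and labels to a multiple of 8.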
+ return DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
+
+ @staticmethod
+ def _get_cherry_pick_callback(
+ experiment_settings: REINFORCETrainExperimentSettings,
+ tokenizer: PreTrainedTokenizerBase,
+ **kwargs,
+ ) -> ChatCherryPickCallback:
+ cherry_pick_settings = experiment_settings.cherry_pick_settings
+
+ cherry_pick_datasets = DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+ cherry_pick_settings.dataset_settings,
+ tokenizer=tokenizer,
+ strategy=DatasetStrategy.INFERENCE,
+ )
+
+ metrics = [
+ Metric.by_name(metric.type)(MetricSettingsRegistry.by_name(metric.type)(**metric.parameters))
+ for metric in cherry_pick_settings.metric_settings
+ ]
+
+ return ChatCherryPickCallback(
+ cherry_pick_settings=cherry_pick_settings,
+ datasets=cherry_pick_datasets,
+ metrics=metrics,
+ )
+
+ @staticmethod
+ def _get_training_args(
+ experiment_settings: REINFORCETrainExperimentSettings,
+ ) -> REINFORCETrainingArguments:
+ return REINFORCETrainingArguments(
+ output_dir=str(experiment_settings.log_path / TRAINER_LOGS_FOLDER),
+ label_names=[],
+ remove_unused_columns=False,
+ **experiment_settings.trainer_settings.dict(),
+ )
+
+ @staticmethod
+ def _get_trainer(
+ training_args: REINFORCETrainingArguments,
+ experiment_settings: REINFORCETrainExperimentSettings,
+ model: PreTrainedModel,
+ tokenizer: PreTrainedTokenizerBase,
+ train_dataset: Dataset,
+ val_dataset: Dataset,
+ data_collator: Callable,
+ ):
+ # TODO: different tokenizer for reward model
+
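+        # Both auxiliary models are frozen: the reference model is loaded from the policy settings
+        # and the reward model from reward_model_settings; only the policy receives gradient updates.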
+ ref_model = load_model(experiment_settings.model_settings, tokenizer)
+ for _, param in ref_model.named_parameters():
+ param.requires_grad = False
+
+ ref_model.eval()
+
+ reward_model = load_model(experiment_settings.reward_model_settings, tokenizer)
+ for _, param in reward_model.named_parameters():
+ param.requires_grad = False
+
+ reward_model.eval()
+
+ return REINFORCETrainer(
+ args=training_args,
+ tokenizer=tokenizer,
+ policy=model,
+ ref_model=ref_model,
+ train_dataset=train_dataset,
+ eval_dataset=val_dataset,
+ data_collator=data_collator,
+ reward_model=reward_model,
+ callbacks=[],
+ )
+
+ def _dataset_and_collator_sanity_check(self, dataset: Dataset, collator: DataCollatorMixin) -> None:
+ logger.info(f'Train sample input_ids:\n{dataset[0]}')
+ logger.info(f'Train sample example:\n{self.tokenizer.decode(dataset[0]["input_ids"])}')
+
+ def _get_datasets(self, experiment_settings: REINFORCETrainExperimentSettings) -> tuple[Dataset, Dataset]:
+ train_dataset: ConcatDataset = ConcatDataset(
+ datasets=DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+ experiment_settings.train_dataset_settings,
+ tokenizer=self.tokenizer,
+ strategy=DatasetStrategy.INFERENCE,
+ )
+ )
+
+ val_dataset: ConcatDataset = ConcatDataset(
+ datasets=DatasetLoader[InferenceChatDataset](InferenceChatDataset).load_datasets(
+ experiment_settings.val_dataset_settings,
+ tokenizer=self.tokenizer,
+ strategy=DatasetStrategy.INFERENCE,
+ )
+ )
+ return train_dataset, val_dataset
diff --git a/turbo_alignment/settings/generators/chat.py b/turbo_alignment/settings/generators/chat.py
index 01bcc09..b672bec 100755
--- a/turbo_alignment/settings/generators/chat.py
+++ b/turbo_alignment/settings/generators/chat.py
@@ -4,4 +4,6 @@
class CustomChatGenerationSettings(ExtraFieldsNotAllowedBaseModel):
skip_special_tokens: bool = True
remove_prompt: bool = True
+ only_answer_logits: bool = True
batch: int = 1
+ return_logits: bool = False
diff --git a/turbo_alignment/settings/generators/outputs/chat.py b/turbo_alignment/settings/generators/outputs/chat.py
index 3d4c392..78719a1 100755
--- a/turbo_alignment/settings/generators/outputs/chat.py
+++ b/turbo_alignment/settings/generators/outputs/chat.py
@@ -1,6 +1,8 @@
+from typing import Any
+
import torch
-from turbo_alignment.dataset.chat import ChatDatasetRecord
+from turbo_alignment.dataset.chat import ChatMessage
from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel
from turbo_alignment.settings.generators.outputs.base import BaseInferenceOutput
@@ -8,17 +10,26 @@
class AnswerMessage(ExtraFieldsNotAllowedBaseModel):
id: str
content: str
+ answer_token_ids: torch.Tensor
+ answer_attention_mask: torch.Tensor
sequence_score: float | None = None
- input_token_ids: torch.Tensor | None = None
- answer_token_ids: torch.Tensor | None = None
logits: torch.Tensor | None = None
class Config:
arbitrary_types_allowed = True
-class ChatInferenceOutput(BaseInferenceOutput, ChatDatasetRecord):
+class ChatInferenceOutput(BaseInferenceOutput):
+ id: str | None = None
+ input_token_ids: torch.Tensor | None = None
+ input_attention_mask: torch.Tensor | None = None
answers: list[AnswerMessage]
+ messages: list[ChatMessage] | None
+ label: str | None = None
+ meta: dict[str, Any] | None = None
+
+ class Config:
+ arbitrary_types_allowed = True
class RagInferenceOutput(ChatInferenceOutput):
diff --git a/turbo_alignment/settings/generators/outputs/multimodal.py b/turbo_alignment/settings/generators/outputs/multimodal.py
index 6238397..69c21c3 100755
--- a/turbo_alignment/settings/generators/outputs/multimodal.py
+++ b/turbo_alignment/settings/generators/outputs/multimodal.py
@@ -1,5 +1,6 @@
from typing import Annotated
+import torch
from pydantic import Field
from turbo_alignment.dataset.base.models import DatasetRecord
@@ -18,3 +19,7 @@ class MultimodalInferenceOutput(BaseInferenceOutput, DatasetRecord):
]
label: str | None = None
answers: list[AnswerMessage]
+ input_token_ids: torch.Tensor | None = None
+
+ class Config:
+ arbitrary_types_allowed = True
diff --git a/turbo_alignment/settings/online.py b/turbo_alignment/settings/online.py
new file mode 100644
index 0000000..52d14c7
--- /dev/null
+++ b/turbo_alignment/settings/online.py
@@ -0,0 +1,15 @@
+from enum import Enum
+
+
+class LLMActorType(str, Enum):
+ LOCAL_TRANSFORMERS = 'local_transformers'
+ DISTRIBUTED_VLLM = 'distributed_vllm'
+
+
+class CriticType(str, Enum):
+ LOCAL_TRANSFORMERS = 'local_transformers'
+ DISTRIBUTED_VLLM = 'distributed_vllm'
+
+
+class RewardProcessorType(str, Enum):
+ RLOO = 'rloo'
diff --git a/turbo_alignment/settings/pipelines/__init__.py b/turbo_alignment/settings/pipelines/__init__.py
index 1b6143a..0a4d7c0 100755
--- a/turbo_alignment/settings/pipelines/__init__.py
+++ b/turbo_alignment/settings/pipelines/__init__.py
@@ -14,5 +14,8 @@
MultimodalTrainExperimentSettings,
)
from turbo_alignment.settings.pipelines.train.rag import RAGTrainExperimentSettings
+from turbo_alignment.settings.pipelines.train.reinforce import (
+ REINFORCETrainExperimentSettings,
+)
from turbo_alignment.settings.pipelines.train.rm import RMTrainExperimentSettings
from turbo_alignment.settings.pipelines.train.sft import SftTrainExperimentSettings
diff --git a/turbo_alignment/settings/pipelines/train/reinforce.py b/turbo_alignment/settings/pipelines/train/reinforce.py
new file mode 100644
index 0000000..be1c71a
--- /dev/null
+++ b/turbo_alignment/settings/pipelines/train/reinforce.py
@@ -0,0 +1,37 @@
+from turbo_alignment.settings.cherry_pick import ChatCherryPickSettings
+from turbo_alignment.settings.datasets import ChatMultiDatasetSettings
+from turbo_alignment.settings.model import (
+ PreTrainedAdaptersModelSettings,
+ PreTrainedModelSettings,
+)
+from turbo_alignment.settings.pipelines.train.base import BaseTrainExperimentSettings
+from turbo_alignment.settings.tf.trainer import TrainerSettings
+
+
+class REINFORCETrainerSettings(TrainerSettings):
+ max_tokens_count: int = 1024
+ stop_token: str = ''
+
+ penalty_reward_value: float = 0.1
+ clip_rewards_min: float = 0.1
+ clip_rewards_max: float = 1.0
+ kl_coef: float = 0.05
+ mean_baseline_coef: float = 0.1
+
+ num_generations: int = 3
+ num_samples_for_reward_stats: int = 0
+
+ non_eos_penalty: bool = True
+ temperature: float | None = None
+ whiten_rewards: bool = False
+
+
+class REINFORCETrainExperimentSettings(BaseTrainExperimentSettings):
+ train_dataset_settings: ChatMultiDatasetSettings
+ val_dataset_settings: ChatMultiDatasetSettings
+
+ cherry_pick_settings: ChatCherryPickSettings
+
+ reward_model_settings: PreTrainedModelSettings | PreTrainedAdaptersModelSettings
+
+ trainer_settings: REINFORCETrainerSettings
diff --git a/turbo_alignment/settings/tf/generation.py b/turbo_alignment/settings/tf/generation.py
index 14537d9..f9a5eae 100755
--- a/turbo_alignment/settings/tf/generation.py
+++ b/turbo_alignment/settings/tf/generation.py
@@ -3,7 +3,7 @@
class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel):
num_beams: int = 1
- max_new_tokens: int = 15
+ max_new_tokens: int | None = None
repetition_penalty: float = 1.0
num_return_sequences: int = 1
do_sample: bool = True
@@ -11,3 +11,4 @@ class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel):
top_k: int = 50
temperature: float = 1.0
stop_strings: str | list[str] = ''
+ max_length: int = 20
diff --git a/turbo_alignment/trainers/online/ray/vllm_engine.py b/turbo_alignment/trainers/online/ray/vllm_engine.py
new file mode 100644
index 0000000..b851ae1
--- /dev/null
+++ b/turbo_alignment/trainers/online/ray/vllm_engine.py
@@ -0,0 +1,118 @@
+import ray
+from ray.util.placement_group import placement_group
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+from turbo_alignment.common.logging import get_project_logger
+
+logger = get_project_logger()
+
+
+@ray.remote
+class LLMRayActor:
+ def __init__(self, *args, **kwargs):
+ import vllm
+
+ self.__version__ = vllm.__version__
+ assert self.__version__ >= '0.4.1', 'turbo_alignment requires vLLM >= 0.4.1'
+
+ self.use_gpu_executor = kwargs['tensor_parallel_size'] == 1
+
+ # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
+ if self.use_gpu_executor:
+ from turbo_alignment.trainers.online.ray.vllm_worker_wrap import WorkerWrap
+
+ vllm.worker.worker.Worker = WorkerWrap
+ else:
+ # RayGPUExecutor
+ # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
+ kwargs['worker_use_ray'] = True
+
+ if vllm.__version__ > '0.4.1':
+ RayWorkerWrapperPath = vllm.executor.ray_utils
+ else:
+ RayWorkerWrapperPath = vllm.engine.ray_utils
+
+ class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
+ def __init__(self, *args, **kwargs) -> None:
+ kwargs['worker_module_name'] = 'turbo_alignment.trainers.online.ray.vllm_worker_wrap'
+ kwargs['worker_class_name'] = 'WorkerWrap'
+ super().__init__(*args, **kwargs)
+
+ RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper
+
+ self.llm = vllm.LLM(*args, **kwargs)
+
+ def generate(self, *args, **kwargs):
+ return self.llm.generate(*args, **kwargs)
+
+ def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend):
+ if self.use_gpu_executor:
+ return self.llm.llm_engine.model_executor.driver_worker.init_process_group(
+ master_address, master_port, rank_offset, world_size, group_name, backend
+ )
+ else:
+ return self.llm.llm_engine.model_executor._run_workers(
+ 'init_process_group', master_address, master_port, rank_offset, world_size, group_name, backend
+ )
+
+ def update_weight(self, name, dtype, shape, empty_cache=False):
+ self.stop_remote_worker_execution_loop()
+
+ if self.use_gpu_executor:
+ return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
+ else:
+ return self.llm.llm_engine.model_executor._run_workers('update_weight', name, dtype, shape, empty_cache)
+
+ def stop_remote_worker_execution_loop(self):
+ # Fix an error when using two communication groups
+ # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4
+ if self.__version__ > '0.4.2':
+ self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()
+
+
+def create_vllm_engines(
+ num_engines: int,
+ tensor_parallel_size: int,
+ pretrain: str,
+ seed: int,
+ enable_prefix_caching: bool,
+ max_model_len: int,
+):
+ vllm_engines = []
+ for i in range(num_engines):
+ # When tensor_parallel_size == 1, vLLM initializes the model inside LLMEngine directly, so assign one GPU to this actor.
+ num_gpus = int(tensor_parallel_size == 1)
+ scheduling_strategy = None
+
+ if tensor_parallel_size > 1:
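+ # With tensor parallelism, reserve one GPU+CPU bundle per TP rank in a
+ # Ray placement group: the engine actor is pinned to bundle 0 and the
+ # vLLM workers it spawns are captured into the same group.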
+ bundles = [{'GPU': 1, 'CPU': 1}] * tensor_parallel_size
+ pg = placement_group(bundles)
+ ray.get(pg.ready())
+
+ scheduling_strategy = PlacementGroupSchedulingStrategy(
+ placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
+ )
+
+ vllm_engines.append(
+ LLMRayActor.options(
+ num_cpus=1,
+ num_gpus=num_gpus,
+ scheduling_strategy=scheduling_strategy,
+ ).remote(
+ pretrain,
+ trust_remote_code=True,
+ tensor_parallel_size=tensor_parallel_size,
+ dtype='bfloat16',
+ seed=seed + i,
+ enable_prefix_caching=enable_prefix_caching,
+ max_model_len=max_model_len,
+ )
+ )
+
+ return vllm_engines
+
+
+if __name__ == '__main__':
+ llm = LLMRayActor.remote('meta-llama/Llama-2-7b-chat-hf', tensor_parallel_size=4)
+ output = ray.get(llm.generate.remote('San Francisco is a'))
+ print(f'output: {output}')
diff --git a/turbo_alignment/trainers/online/ray/vllm_worker_wrap.py b/turbo_alignment/trainers/online/ray/vllm_worker_wrap.py
new file mode 100644
index 0000000..c40a39d
--- /dev/null
+++ b/turbo_alignment/trainers/online/ray/vllm_worker_wrap.py
@@ -0,0 +1,43 @@
+import torch
+from vllm.worker.worker import Worker
+
+from turbo_alignment.common.distributed import init_process_group
+from turbo_alignment.common.logging import get_project_logger
+
+logger = get_project_logger()
+
+
+class WorkerWrap(Worker):
+ def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend='nccl'):
+ """Init torch process group for model weights update"""
+ assert torch.distributed.is_initialized(), 'default torch process group must be initialized'
+ assert group_name != '', 'group name must not be empty'
+
+ rank = torch.distributed.get_rank() + rank_offset
+ self._model_update_group = init_process_group(
+ backend=backend,
+ init_method=f'tcp://{master_address}:{master_port}',
+ world_size=world_size,
+ rank=rank,
+ group_name=group_name,
+ )
+ logger.info(
+ f'init_process_group: master_address={master_address}, master_port={master_port}, '
+ f'rank={rank}, world_size={world_size}, group_name={group_name}'
+ )
+
+ def update_weight(self, name, dtype, shape, empty_cache=False):
+ """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
+ if torch.distributed.get_rank() == 0:
+ logger.info(f'update weight: {name}, dtype: {dtype}, shape: {shape}')
+
+ assert dtype == self.model_config.dtype, f'mismatch dtype: src {dtype}, dst {self.model_config.dtype}'
+ weight = torch.empty(shape, dtype=dtype, device='cuda')
+ torch.distributed.broadcast(weight, 0, group=self._model_update_group)
+
+ self.model_runner.model.load_weights(weights=[(name, weight)])
+
+ del weight
+ # TODO: should we empty cache if all weights have updated?
+ # if empty_cache:
+ # torch.cuda.empty_cache()
diff --git a/turbo_alignment/trainers/online/reinforce.py b/turbo_alignment/trainers/online/reinforce.py
new file mode 100644
index 0000000..85c8e20
--- /dev/null
+++ b/turbo_alignment/trainers/online/reinforce.py
@@ -0,0 +1,388 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.data
+from datasets import Dataset
+from transformers import (
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ TrainerCallback,
+ TrainingArguments,
+)
+
+from turbo_alignment.common.distributed import (
+ get_global_mean,
+ get_global_std,
+ get_log_mean_std,
+)
+from turbo_alignment.generators import ChatGenerator, RMSamplingGenerator
+from turbo_alignment.generators.base import ChatGeneratorBase
+from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings
+from turbo_alignment.settings.generators.outputs.chat import ChatInferenceOutput
+from turbo_alignment.settings.online import (
+ CriticType,
+ LLMActorType,
+ RewardProcessorType,
+)
+from turbo_alignment.settings.tf.generation import GeneratorTransformersSettings
+from turbo_alignment.trainers.multigpu import MultiGPUCherryPicksTrainer
+from turbo_alignment.trainers.online.reward_processor import RewardProcessor
+from turbo_alignment.trainers.utils import prepare_model
+
+
+# FIXME
+@dataclass
+class REINFORCETrainingArguments(TrainingArguments):
+ max_tokens_count: int = 1024
+ stop_token: str = ''
+
+ penalty_reward_value: float = 0.1
+ clip_rewards_min: float = 0.1
+ clip_rewards_max: float = 1.0
+ kl_coef: float = 0.05
+ mean_baseline_coef: float = 0.1
+
+ num_generations: int = 3
+ num_samples_for_reward_stats: int = 0
+
+ non_eos_penalty: bool = True
+ temperature: float | None = None
+ whiten_rewards: bool = False
+
+ actor_type: LLMActorType = LLMActorType.LOCAL_TRANSFORMERS
+ critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS
+
+ reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO
+
+
+class REINFORCETrainer(MultiGPUCherryPicksTrainer):
+ def __init__(
+ self,
+ args: REINFORCETrainingArguments,
+ tokenizer: PreTrainedTokenizerBase,
+ policy: PreTrainedModel | torch.nn.Module,
+ reward_model: PreTrainedModel | torch.nn.Module,
+ train_dataset: Dataset,
+ eval_dataset: Dataset,
+ ref_model: PreTrainedModel | torch.nn.Module | None = None,
+ callbacks: list[TrainerCallback] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ model=policy,
+ args=args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ tokenizer=tokenizer,
+ callbacks=callbacks,
+ **kwargs,
+ )
+
+ self.ref_model = ref_model
+ if self.ref_model is not None:
+ self.ref_model = prepare_model(self.ref_model, self.accelerator, self.is_deepspeed_enabled)
+
+ self.reward_model = reward_model
+
+ self._stored_metrics: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
+ self.kl_coef = args.kl_coef
+
+ self.gen_id = 1
+
+ # Take gradient accumulation into account: rescale the EMA coefficient
+ # so that the baseline still averages over the same effective number of
+ # optimizer updates regardless of gradient_accumulation_steps.
+ effective_num_previous_samples = 1 / (1 - self.args.mean_baseline_coef)
+ self.args.mean_baseline_coef = 1 - 1 / (effective_num_previous_samples * self.args.gradient_accumulation_steps)
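+ # Illustrative numbers: with mean_baseline_coef=0.9 the running baseline
+ # effectively averages 1 / (1 - 0.9) = 10 samples; with
+ # gradient_accumulation_steps=4 the coefficient becomes
+ # 1 - 1 / (10 * 4) = 0.975, which still spans 10 optimizer updates.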
+
+ self.stop_generation_token_id = tokenizer.encode(args.stop_token, add_special_tokens=False)
+ assert len(self.stop_generation_token_id) == 1, self.stop_generation_token_id
+
+ self.generator_transformers_settings = GeneratorTransformersSettings(
+ temperature=args.temperature,
+ max_length=args.max_tokens_count,
+ num_return_sequences=args.num_generations,
+ num_beams=args.num_generations,
+ do_sample=True,
+ )
+ self.generator_custom_settings = CustomChatGenerationSettings(
+ batch=self.args.per_device_train_batch_size,
+ remove_prompt=False,
+ only_answer_logits=True,
+ return_logits=True,
+ )
+ self.reward_processor = RewardProcessor.by_name(args.reward_processor_type)(
+ mean_baseline_coef=args.mean_baseline_coef,
+ num_generations=args.num_generations,
+ )
+
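+ # Estimate reward normalization statistics once before training from up
+ # to num_samples_for_reward_stats generated samples; when that setting
+ # is 0, mean=0 and std=1 make the standardization in process_rewards a
+ # no-op.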
+ self.norm_reward_mean, self.norm_reward_std = self.reward_stats(
+ model=self.model, dataloader=self.get_train_dataloader()
+ )
+
+ # FIXME: some base class instead of RMSamplingGenerator (join with ChatGeneratorBase?)
+ def _get_rm_generator(self, reward_model: torch.nn.Module | PreTrainedModel) -> RMSamplingGenerator:
+ match self.args.critic_type:
+ case CriticType.LOCAL_TRANSFORMERS:
+ generator = RMSamplingGenerator(
+ model=reward_model,
+ tokenizer=self.tokenizer,
+ accelerator=self.accelerator,
+ )
+ case CriticType.DISTRIBUTED_VLLM:
+ generator = ...
+ case _:
+ raise ValueError(f'Critic {self.args.critic_type} is not supported')
+ return generator
+
+ def _get_chat_generator(self, model: torch.nn.Module | PreTrainedModel) -> ChatGeneratorBase:
+ match self.args.actor_type:
+ case LLMActorType.LOCAL_TRANSFORMERS:
+ generator = ChatGenerator(
+ model=model,
+ tokenizer=self.tokenizer,
+ transformers_settings=self.generator_transformers_settings,
+ custom_generation_settings=self.generator_custom_settings,
+ accelerator=self.accelerator,
+ )
+ case LLMActorType.DISTRIBUTED_VLLM:
+ generator = ...
+ case _:
+ raise ValueError(f'LLM Actor {self.args.actor_type} is not supported')
+ return generator
+
+ def get_answers_and_rewards(
+ self,
+ model: torch.nn.Module | PreTrainedModel,
+ inputs: dict[str, torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ generator = self._get_chat_generator(model)
+
+ generations: list[ChatInferenceOutput] = generator.generate_from_batch_records(
+ dataset_name='online', records_batch=inputs
+ )
+
+ response_ids = torch.stack([ans.answer_token_ids for g in generations for ans in g.answers])
+ response_attention_mask = torch.stack([ans.answer_attention_mask for g in generations for ans in g.answers])
+
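+ # Mark only generated tokens for the loss: everything after the prompt
+ # is set to 1. Presumably all records in the batch share the prompt
+ # length of the first generation (prompts padded to a common length).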
+ response_tokens_mask = torch.zeros(response_ids.shape, dtype=torch.float32)
+ response_tokens_mask[:, generations[0].input_token_ids.shape[0] :] = 1.0
+
+ position_ids = (response_attention_mask.cumsum(-1) - 1).clamp(min=0)
+ position_ids.masked_fill_(response_attention_mask.to(torch.bool) == 0, 0)
+
+ rm_inputs = {
+ 'input_ids': response_ids,
+ 'attention_mask': response_attention_mask,
+ 'position_ids': position_ids,
+ }
+
+ critic_generator = self._get_rm_generator(self.reward_model)
+ rewards = critic_generator.generate_from_batch_records(rm_inputs)
+
+ return response_ids, response_attention_mask, response_tokens_mask, position_ids, rewards
+
+ def get_batch_loss_metrics(
+ self,
+ model: torch.nn.Module | PreTrainedModel,
+ inputs: dict[str, torch.Tensor],
+ train_eval: Literal['train', 'eval'] = 'train',
+ ):
+ with torch.no_grad():
+ (
+ query_response,
+ attention_mask,
+ response_tokens_mask,
+ position_ids,
+ rewards,
+ ) = self.get_answers_and_rewards(
+ model=self.accelerator.unwrap_model(model),
+ inputs=inputs,
+ )
+
+ ref_logprobs = self.get_logprobs(
+ model=self.ref_model,
+ input_ids=query_response,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ loss_mask=response_tokens_mask.to(torch.float32),
+ )
+
+ rewards, valid_mask, rewards_metrics = self.process_rewards(rewards=rewards, query_response=query_response)
+
+ logprobs = self.get_logprobs(
+ model=model,
+ input_ids=query_response,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ loss_mask=response_tokens_mask,
+ )
+
+ with torch.no_grad():
+ kl_term = logprobs.detach() - ref_logprobs
+ regularized_rewards = rewards - self.kl_coef * kl_term
+
+ baselined_reward, baseline_metrics = self.reward_processor.baseline_rewards(rewards=regularized_rewards)
+
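+ # REINFORCE with a baseline: the advantage
+ # A = (reward - kl_coef * (logprob - ref_logprob)) - baseline
+ # is treated as a constant, so minimizing -A * logprob gives the
+ # policy-gradient estimate -A * grad(logprob) per sampled completion.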
+ loss = -baselined_reward * logprobs
+
+ assert len(loss.shape) == 1, loss.shape
+
+ with torch.no_grad():
+ metrics = {}
+ for tensor, name in zip(
+ [
+ rewards,
+ regularized_rewards,
+ baselined_reward,
+ loss.detach(),
+ kl_term.detach(),
+ logprobs.detach(),
+ 1 - valid_mask.float(),
+ ],
+ [
+ 'rewards',
+ 'regularized_rewards',
+ 'baselined_rewards',
+ 'loss',
+ 'kl_term',
+ 'logprobs',
+ 'invalid_rewards',
+ ],
+ ):
+ metrics = {
+ **metrics,
+ **get_log_mean_std(tensor, name, train_eval),
+ }
+ metrics[f'{train_eval}/kl_coef'] = self.kl_coef
+
+ for k, v in rewards_metrics.items():
+ metrics[f'{train_eval}/{k}'] = v
+ for k, v in baseline_metrics.items():
+ metrics[f'{train_eval}/{k}'] = v
+
+ return loss.mean(), metrics
+
+ def compute_loss(self, model, inputs, return_outputs: bool = False):
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, 'train')
+
+ self.store_metrics(metrics=metrics, train_eval='train')
+
+ return (loss, metrics) if return_outputs else loss
+
+ def prediction_step(
+ self,
+ model: Union[PreTrainedModel, torch.nn.Module],
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ with torch.no_grad():
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, 'eval')
+
+ self.store_metrics(metrics, train_eval='eval')
+
+ return loss.detach(), None, None
+
+ @torch.no_grad()
+ def reward_stats(
+ self, model: torch.nn.Module | PreTrainedModel, dataloader: torch.utils.data.DataLoader
+ ) -> tuple[float, float]:
+ if self.args.num_samples_for_reward_stats == 0:
+ norm_reward_mean = 0
+ norm_reward_std = 1
+ else:
+ all_rewards = []
+ samples_processed = 0
+ for batch in dataloader:
+ _, _, _, _, rewards = self.get_answers_and_rewards(
+ model=self.accelerator.unwrap_model(model),
+ inputs=batch,
+ )
+ rewards, _ = self.reward_processor.postprocess_rewards(rewards=rewards)
+
+ all_rewards.append(rewards)
+ samples_processed += len(batch['input_ids']) * self.accelerator.num_processes
+
+ if samples_processed > self.args.num_samples_for_reward_stats:
+ break
+
+ all_rewards = torch.cat(all_rewards, 0)
+
+ norm_reward_mean = get_global_mean(all_rewards)
+ norm_reward_std = get_global_std(all_rewards, mean=norm_reward_mean)
+
+ return norm_reward_mean, norm_reward_std
+
+ def get_logprobs(
+ self,
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_ids: torch.Tensor,
+ loss_mask: torch.Tensor,
+ ):
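+ # Causal shift: the logit at position t predicts token t + 1, so the
+ # last logit is dropped and targets are gathered from input_ids[:, 1:];
+ # tokens outside loss_mask contribute zero and the per-sequence
+ # log-probability is the sum over answer tokens.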
+ logits = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ use_cache=False,
+ ).logits[:, :-1]
+ if self.args.temperature is not None:
+ logits /= self.args.temperature
+
+ all_logprob = F.log_softmax(logits, dim=-1)
+
+ logprob = torch.gather(all_logprob, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
+ logprob[~loss_mask[:, 1:].to(torch.bool)] = 0
+
+ return logprob.sum(-1)
+
+ def fill_nonvalid_rewards(self, rewards, query_response) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.args.non_eos_penalty:
+ invalid_mask = query_response[:, -1] != self.stop_generation_token_id[0]
+ rewards[invalid_mask] = self.args.penalty_reward_value
+
+ return rewards, ~invalid_mask
+
+ return rewards, torch.ones_like(rewards).to(torch.bool)
+
+ def process_rewards(self, rewards, query_response) -> tuple[torch.Tensor, torch.Tensor, Any]:
+ """
+ Second, this token should be the .
+ Otherwise, return a fixed reward of -1.
+ """
+
+ rewards, reward_metrics = self.reward_processor.postprocess_rewards(rewards=rewards)
+ rewards = rewards.squeeze(-1)
+ reward_metrics['normalizing_reward_mean'] = self.norm_reward_mean
+ reward_metrics['normalizing_reward_std'] = self.norm_reward_std
+ rewards = (rewards - self.norm_reward_mean) / self.norm_reward_std
+ # bounded rewards are required: https://arxiv.org/pdf/2406.01462
+ rewards = torch.clamp(rewards, self.args.clip_rewards_min, self.args.clip_rewards_max)
+
+ rewards, valid_mask = self.fill_nonvalid_rewards(rewards=rewards, query_response=query_response)
+
+ return rewards, valid_mask, reward_metrics
+
+ def store_metrics(self, metrics: Dict[str, float], train_eval: Literal['train', 'eval'] = 'train') -> None:
+ for key, value in metrics.items():
+ self._stored_metrics[train_eval][key].append(value)
+
+ def log(self, logs: Dict[str, float]) -> None:
+ main_keys = list(self._stored_metrics.keys())
+ for main_key in main_keys:
+ for key, metrics in self._stored_metrics[main_key].items():
+ logs[key] = torch.tensor(metrics).float().cpu().mean().item()
+
+ if main_key == 'train':
+ logs['train/global_step'] = int(self.state.global_step)
+ else:
+ logs['eval/global_step'] = int(self.state.global_step // self.args.eval_steps)
+
+ del self._stored_metrics[main_key]
+
+ return super().log(logs) # pylint: disable=no-member
diff --git a/turbo_alignment/trainers/online/reward_actor.py b/turbo_alignment/trainers/online/reward_actor.py
new file mode 100644
index 0000000..e69de29
diff --git a/turbo_alignment/trainers/online/reward_processor.py b/turbo_alignment/trainers/online/reward_processor.py
new file mode 100644
index 0000000..50d126c
--- /dev/null
+++ b/turbo_alignment/trainers/online/reward_processor.py
@@ -0,0 +1,67 @@
+from abc import ABC
+
+import torch
+from allenai_common import Registrable
+
+from turbo_alignment.common.distributed import get_global_mean, get_log_mean_std
+from turbo_alignment.settings.online import RewardProcessorType
+
+
+class RewardProcessor(ABC, Registrable):
+ def __init__(
+ self,
+ mean_baseline_coef: float,
+ num_generations: int,
+ ) -> None:
+ self.mean_reward: float | None = None
+ self.mean_baseline_coef = mean_baseline_coef
+ self.num_generations = num_generations
+
+ def postprocess_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
+ raise NotImplementedError
+
+ def baseline_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
+ baseline: float = self.mean_reward if self.mean_reward is not None else 0
+ advantages: torch.Tensor = rewards - baseline
+
+ with torch.no_grad():
+ metrics: dict[str, float | int] = {'baseline_mean': baseline}
+
+ self.update_mean_reward(rewards=rewards)
+
+ return advantages, metrics
+
+ @torch.no_grad()
+ def update_mean_reward(self, rewards: torch.Tensor):
+ global_mean_reward: float = get_global_mean(rewards)
+
+ if self.mean_reward is None:
+ self.mean_reward = global_mean_reward
+ else:
+ self.mean_reward = (
+ self.mean_baseline_coef * self.mean_reward + (1 - self.mean_baseline_coef) * global_mean_reward
+ )
+
+
+@RewardProcessor.register(RewardProcessorType.RLOO)
+class RLOORewardProcessor(RewardProcessor):
+ def postprocess_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
+ rewards = rewards[:, 0].unsqueeze(-1)  # column 0 is the reward score; value estimates are stored at column 1
+
+ with torch.no_grad():
+ metrics: dict[str, float] = get_log_mean_std(rewards, 'real_reward')
+
+ return rewards, metrics
+
+ def baseline_rewards(self, rewards: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
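+ # RLOO leave-one-out baseline: rewards are reshaped to
+ # [num_prompts, num_generations] and each sample is baselined by the
+ # mean reward of the other num_generations - 1 completions of the same
+ # prompt, then flattened back.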
+ rewards = rewards.reshape(-1, self.num_generations)
+ baseline: torch.Tensor = (rewards.sum(-1).unsqueeze(-1) - rewards) / (self.num_generations - 1)
+ rloo_advantages: torch.Tensor = (rewards - baseline).flatten()
+
+ with torch.no_grad():
+ metrics: dict[str, float] = {
+ **get_log_mean_std(baseline, 'baseline'),
+ **get_log_mean_std(baseline.std(-1), 'baseline_inner_std'),
+ }
+
+ return rloo_advantages, metrics