diff --git a/Makefile b/Makefile
index 8fa7f59..94724e2 100644
--- a/Makefile
+++ b/Makefile
@@ -39,6 +39,9 @@ ifneq ($(CODE),)
unify --in-place --recursive $(CODE) tutorials
endif
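+# Open an interactive bash shell in the dev image with the repository mounted at /app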
+bash-dev:
+ docker run -v $(shell pwd):/app --rm -it ${DEV_IMAGE} bash
+
lock:
rm poetry.lock
poetry lock
diff --git a/README.md b/README.md
index 0ce8668..5ed343f 100755
--- a/README.md
+++ b/README.md
@@ -56,7 +56,6 @@ Turbo-Alignment supports a wide range of methods for model training and alignmen
- **📏** Length
- **🌀** Perplexity
- **🌟** METEOR
-- **📐** RAGAS
- **🔍** Retrieval Utility
@@ -86,7 +85,7 @@ Examples of datasets are available [here](docs/dataset_example.md).
- [RSO](#-RSO-sampling)
- [Common](#-common)
- [Preprocess](#-preprocess-common)
- - [Convert to base](#-convert-to-base-common)
+ - [Merge adapters to base](#-merge-adapters-to-base-common)
# Train
@@ -170,7 +169,9 @@ To launch RAG:
## 🚀 Installation
### 📦 Python Package
-⌛️ in progress..
+```bash
+pip install turbo-alignment
+```
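+
+Once installed, pipelines are launched via the CLI. A minimal sketch for supervised fine-tuning (the module-style invocation and config path are illustrative; see the train and inference sections below for concrete examples):
+
+```bash
+python -m turbo_alignment train_sft --experiment_settings_path configs/sft.json
+```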
### 🛠️ From Source
For the latest features before an official release:
diff --git a/docs/dataset_example.md b/docs/dataset_example.md
index 7ca2f81..6ea91c5 100644
--- a/docs/dataset_example.md
+++ b/docs/dataset_example.md
@@ -11,7 +11,7 @@
- [Pair Preferences Dataset](#-pair-preferences-dataset)
- [KTO Dataset](#-kto-dataset)
- [Sampling Dataset](#-sampling-dataset)
-- [Multimodal Dataset ](#-multimodal-dataset) (⌛️ Work in progress...)
+- [Multimodal Dataset ](#-multimodal-dataset)
- [Classification Dataset](#-classification-dataset)
- [DDPO Dataset](#-ddpo-dataset) (⌛️ Work in progress...)
@@ -118,9 +118,45 @@ Example:
## Multimodal Dataset
-⌛️ in progress..
+- `messages`: `list[MultimodalChatMessage]` — a sequence of messages that make up the chat history. Each `MultimodalChatMessage` includes:
+  - `role` — the participant's role in the conversation (e.g., `user` or `bot`).
+  - `type` — the modality type (e.g., `text` or `image`).
+  - `content` — the textual content of the message if `type` is `text`, or the path to the image file if `type` is `image`.
+Example:
+```json
+{
+ "id": "0",
+ "messages": [
+ {
+ "role": "system",
+ "type": "text",
+ "content": "You are a Multimodal AI assistant."
+ },
+ {
+ "role": "user",
+ "type": "image",
+ "content": "/path/to/cat.jpg"
+ },
+ {
+ "role": "user",
+ "type": "image",
+ "content": "/path/to/dog.jpg"
+ },
+ {
+ "role": "user",
+ "type": "text",
+ "content": "What's the difference between these two images?"
+ },
+ {
+ "role": "bot",
+ "type": "text",
+      "content": "The two images both feature animals, albeit of different species. The first image depicts a cat, while the second image features a dog."
+ }
+ ]
+}
+```
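+
+Dataset files store one JSON record per line (`.jsonl`), as in `tests/fixtures/datasets/multimodal/image_chat.jsonl`. A minimal sketch for loading such a file with the standard library (helper name is illustrative):
+
+```python
+import json
+
+
+def read_records(path: str) -> list[dict]:
+    """Load a .jsonl dataset: one JSON object per non-empty line."""
+    with open(path, encoding='utf-8') as f:
+        return [json.loads(line) for line in f if line.strip()]
+```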
diff --git a/pyproject.toml b/pyproject.toml
index 18991d7..c8a6871 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "turbo-alignment"
packages = [
{ include = "turbo_alignment" },
]
-version = "0.1.0"
+version = "0.0.2"
description = "turbo-alignment repository"
authors = ["T Mega Alignment Team " ]
@@ -44,7 +44,6 @@ pydantic ="^2.7.0"
timm ="^0.9.7"
opencv-python = "^4.10.0.84"
langchain-huggingface = "^0.0.3"
-ragas = "^0.1.10"
[tool.poetry.group.deepspeed.dependencies]
accelerate = "0.27"
diff --git a/tests/cli/test_dpo_train.py b/tests/cli/test_dpo_train.py
index 08e1f76..23bd4af 100755
--- a/tests/cli/test_dpo_train.py
+++ b/tests/cli/test_dpo_train.py
@@ -12,7 +12,10 @@
@pytest.mark.parametrize(
'config_path',
- [FIXTURES_PATH / 'configs/train/dpo/base.json'],
+ [
+ FIXTURES_PATH / 'configs/train/dpo/base.json',
+ FIXTURES_PATH / 'configs/train/dpo/simpo.json',
+ ],
)
def test_dpo_train(config_path: Path):
result = runner.invoke(
diff --git a/tests/cli/test_multimodal_inference.py b/tests/cli/test_multimodal_inference.py
index 91d1b0e..60c9a05 100644
--- a/tests/cli/test_multimodal_inference.py
+++ b/tests/cli/test_multimodal_inference.py
@@ -1,26 +1,26 @@
-# from pathlib import Path
+from pathlib import Path
-# import pytest
-# from typer.testing import CliRunner
+import pytest
+from typer.testing import CliRunner
-# from tests.constants import FIXTURES_PATH
-# from turbo_alignment.cli import app
-# from turbo_alignment.settings.pipelines.inference.multimodal import (
-# MultimodalInferenceExperimentSettings,
-# )
+from tests.constants import FIXTURES_PATH
+from turbo_alignment.cli import app
+from turbo_alignment.settings.pipelines.inference.multimodal import (
+ MultimodalInferenceExperimentSettings,
+)
-# runner = CliRunner()
+runner = CliRunner()
-# @pytest.mark.parametrize(
-# 'config_path',
-# [
-# FIXTURES_PATH / 'configs/inference/multimodal/llama_llava_clip_pickle.json',
-# ],
-# )
-# def test_multimodal_inference_mlp_with_preprocessing(config_path: Path):
-# result = runner.invoke(
-# app, ['inference_multimodal', '--inference_settings_path', str(config_path)], catch_exceptions=False
-# )
-# assert result.exit_code == 0
-# assert MultimodalInferenceExperimentSettings.parse_file(config_path).save_path.is_dir()
+@pytest.mark.parametrize(
+ 'config_path',
+ [
+ FIXTURES_PATH / 'configs/inference/multimodal/llama_llava_clip_pickle.json',
+ ],
+)
+def test_multimodal_inference_mlp_with_preprocessing(config_path: Path):
+ result = runner.invoke(
+ app, ['inference_multimodal', '--inference_settings_path', str(config_path)], catch_exceptions=False
+ )
+ assert result.exit_code == 0
+ assert MultimodalInferenceExperimentSettings.parse_file(config_path).save_path.is_dir()
diff --git a/tests/cli/test_sft.py b/tests/cli/test_sft.py
index f66941f..ab01abd 100755
--- a/tests/cli/test_sft.py
+++ b/tests/cli/test_sft.py
@@ -14,24 +14,10 @@
'config_path',
[
FIXTURES_PATH / 'configs/train/sft/base.json',
+ FIXTURES_PATH / 'configs/train/sft/sft_with_rm_metric.json',
],
)
def test_sft_train(config_path: Path):
result = runner.invoke(app, ['train_sft', '--experiment_settings_path', str(config_path)], catch_exceptions=False)
assert result.exit_code == 0
assert SftTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
-
-
-@pytest.mark.parametrize(
- 'config_path',
- [
- FIXTURES_PATH / 'configs/train/sft/resume_from_checkpoint.json',
- ],
-)
-def test_sft_from_checkpoint(config_path: Path):
- result = runner.invoke(
- app,
- ['train_sft', '--experiment_settings_path', str(config_path)],
- )
- assert result.exit_code == 0
- assert SftTrainExperimentSettings.parse_file(config_path).log_path.is_dir()
diff --git a/tests/fixtures/configs/inference/rag/base.json b/tests/fixtures/configs/inference/rag/base.json
index 2282b7b..dff8d56 100755
--- a/tests/fixtures/configs/inference/rag/base.json
+++ b/tests/fixtures/configs/inference/rag/base.json
@@ -41,10 +41,10 @@
"num_beams": 1,
"max_new_tokens": 10,
"repetition_penalty": 1.2,
+ "stop_strings": "",
"do_sample": false
},
"custom_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false,
"remove_prompt": false
}
diff --git a/tests/fixtures/configs/train/dpo/base.json b/tests/fixtures/configs/train/dpo/base.json
index 8df1392..a06caad 100755
--- a/tests/fixtures/configs/train/dpo/base.json
+++ b/tests/fixtures/configs/train/dpo/base.json
@@ -56,10 +56,10 @@
"generator_transformers_settings": {
"num_beams": 1,
"do_sample": false,
+ "stop_strings": "",
"max_new_tokens": 8
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/dpo/simpo.json b/tests/fixtures/configs/train/dpo/simpo.json
new file mode 100755
index 0000000..593463d
--- /dev/null
+++ b/tests/fixtures/configs/train/dpo/simpo.json
@@ -0,0 +1,134 @@
+{
+ "train_dataset_settings": {
+ "sources": [
+ {
+ "name": "rm_preferences_test",
+ "records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl",
+ "sample_rate": 1
+ }
+ ],
+ "chat_settings":{
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "",
+ "user": "",
+ "system": ""
+ },
+ "prefix_template": "{role}",
+ "suffix_template": ""
+ },
+ "max_tokens_count": 120
+ },
+ "add_labels": true,
+ "dataset_type": "pair_preferences"
+ },
+ "val_dataset_settings": {
+ "sources": [
+ {
+ "name": "rm_preferences_test",
+ "records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl",
+ "sample_rate": 1
+ }
+ ],
+ "chat_settings":{
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "",
+ "user": "",
+ "system": ""
+ },
+ "prefix_template": "{role}",
+ "suffix_template": ""
+ },
+ "max_tokens_count": 120
+ },
+ "add_labels": true,
+ "dataset_type": "pair_preferences"
+ },
+ "model_settings": {
+ "model_path": "tests/fixtures/models/llama2_tiny",
+ "model_type": "causal",
+ "transformers_settings": {},
+ "adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer",
+ "is_trainable": true
+ },
+ "cherry_pick_settings": {
+ "generator_transformers_settings": {
+ "num_beams": 1,
+ "do_sample": false,
+ "stop_strings": "",
+ "max_new_tokens": 8
+ },
+ "custom_generation_settings": {
+ "skip_special_tokens": false
+ },
+ "dataset_settings": {
+ "sources": [
+ {
+ "name": "chat_test",
+ "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl",
+ "num_samples": 2
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "",
+ "user": "",
+ "system": ""
+ },
+ "prefix_template": "{role}",
+ "suffix_template": ""
+ },
+ "dataset_type": "chat",
+ "max_tokens_count": 150,
+ "only_answer_loss": true
+ },
+ "metric_settings": [
+ {
+ "type": "length",
+ "parameters": {"need_average": [true]}
+ },
+ {
+ "type": "kl",
+ "parameters": {
+ "need_average": [true],
+ "ref_logits_type": "sft"
+ }
+ }
+ ]
+ },
+ "tokenizer_settings": {},
+ "trainer_settings": {
+ "evaluation_strategy": "steps",
+ "per_device_train_batch_size": 2,
+ "per_device_eval_batch_size": 2,
+ "gradient_accumulation_steps": 2,
+ "eval_steps": 4,
+ "save_steps": 4,
+ "logging_steps": 1,
+ "learning_rate": 0.0003,
+ "num_train_epochs": 2,
+ "lr_scheduler_type": "cosine",
+ "warmup_steps": 2,
+ "fp16": false,
+ "bf16": false,
+ "optim": "adamw_torch",
+ "save_total_limit": 1,
+ "average_log_prob": true,
+ "loss_settings": {
+ "loss_type": "simpo"
+ },
+ "sync_ref_settings": {
+ "sync_ref_model": false
+ },
+ "use_ref_model": false,
+ "use_sft_model": true,
+ "no_cuda": true
+ },
+ "wandb_settings": {
+ "project_name": "alignment",
+        "run_name": "simpo",
+ "entity": "turbo-alignment"
+ },
+ "log_path": "test_dpo_llama_train_output"
+}
diff --git a/tests/fixtures/configs/train/kto/base.json b/tests/fixtures/configs/train/kto/base.json
index 30d4192..28dc7a4 100755
--- a/tests/fixtures/configs/train/kto/base.json
+++ b/tests/fixtures/configs/train/kto/base.json
@@ -54,10 +54,10 @@
"generator_transformers_settings": {
"num_beams": 1,
"do_sample": false,
+ "stop_strings": "",
"max_new_tokens": 8
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json
index f140504..0195034 100644
--- a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json
+++ b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json
@@ -31,7 +31,7 @@
"start_modality_token": "",
"end_modality_token": "",
"dataset_type": "multimodal",
- "max_tokens_count": 2000,
+ "max_tokens_count": 300,
"only_answer_loss": true,
"truncate_top": false
},
@@ -67,7 +67,7 @@
"start_modality_token": "",
"end_modality_token": "",
"dataset_type": "multimodal",
- "max_tokens_count": 2000,
+ "max_tokens_count": 300,
"only_answer_loss": true,
"truncate_top": false
},
@@ -144,12 +144,12 @@
"cherry_pick_settings": {
"generator_transformers_settings": {
"num_beams": 1,
- "max_new_tokens": 128,
+ "max_new_tokens": 16,
"repetition_penalty": 1.1,
+ "stop_strings": "",
"do_sample": true
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": true
},
"dataset_settings": {
@@ -170,7 +170,7 @@
"suffix_template": ""
},
"dataset_type": "multimodal",
- "max_tokens_count": 2000,
+ "max_tokens_count": 300,
"n_modality_embeddings": 225,
"start_modality_token": "",
"end_modality_token": "",
diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json
index 1b51ab2..98fbad8 100644
--- a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json
+++ b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json
@@ -146,10 +146,10 @@
"num_beams": 1,
"max_new_tokens": 128,
"repetition_penalty": 1.1,
+ "stop_strings": "",
"do_sample": true
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": true
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json
index 9f23e43..d7a46e7 100644
--- a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json
+++ b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json
@@ -146,10 +146,10 @@
"num_beams": 1,
"max_new_tokens": 4,
"repetition_penalty": 1.1,
+ "stop_strings": "",
"do_sample": false
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": true
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/rag/base.json b/tests/fixtures/configs/train/rag/base.json
index b003986..f11b4b8 100755
--- a/tests/fixtures/configs/train/rag/base.json
+++ b/tests/fixtures/configs/train/rag/base.json
@@ -87,10 +87,10 @@
"num_beams": 3,
"max_new_tokens": 16,
"repetition_penalty": 1.1,
+ "stop_strings": "",
"do_sample": true
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/sft/base.json b/tests/fixtures/configs/train/sft/base.json
index 8e0b4d4..b5dc9fd 100755
--- a/tests/fixtures/configs/train/sft/base.json
+++ b/tests/fixtures/configs/train/sft/base.json
@@ -69,10 +69,10 @@
"cherry_pick_settings": {
"generator_transformers_settings": {
"num_beams": 3,
+ "stop_strings": ["", ""],
"max_new_tokens": 8
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/sft/prompt_tuning.json b/tests/fixtures/configs/train/sft/prompt_tuning.json
index 0a09dfc..c187176 100755
--- a/tests/fixtures/configs/train/sft/prompt_tuning.json
+++ b/tests/fixtures/configs/train/sft/prompt_tuning.json
@@ -63,10 +63,10 @@
"num_beams": 1,
"max_new_tokens": 35,
"repetition_penalty": 1.1,
+ "stop_strings": "",
"do_sample": true
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false
},
"dataset_settings": {
diff --git a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json
index 2a4515c..c281971 100755
--- a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json
+++ b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json
@@ -70,10 +70,10 @@
"generator_transformers_settings": {
"num_beams": 1,
"num_return_sequences": 2,
+ "stop_strings": "",
"max_new_tokens": 8
},
"custom_generation_settings": {
- "generation_eos_token": "",
"skip_special_tokens": false
},
"dataset_settings": {
diff --git a/tests/fixtures/datasets/multimodal/image_chat.jsonl b/tests/fixtures/datasets/multimodal/image_chat.jsonl
index b621ea1..e53f0b5 100644
--- a/tests/fixtures/datasets/multimodal/image_chat.jsonl
+++ b/tests/fixtures/datasets/multimodal/image_chat.jsonl
@@ -1,5 +1,4 @@
{"id": "0", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_1.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "Sorry, I will not describe the scene."}]}
{"id": "1", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_2.jpg"}, {"role": "user", "type": "text", "content": "What do you see on the image?"}, {"role": "bot", "type": "text", "content": "I see nothing."}, {"role": "user", "type": "text", "content": "What about this one?"}, {"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_1.jpg"}, {"role": "bot", "type": "text", "content": "Sorry..."}]}
{"id": "2", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_3.jpg"}, {"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_4.jpg"}, {"role": "user", "type": "text", "content": "Please, describe these two photos."}, {"role": "bot", "type": "text", "content": "OK."}]}
-{"id": "3", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_5.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "No."}]}
-{"id": "4", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/does_not_exist.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "No."}]}
\ No newline at end of file
+{"id": "3", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_5.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "No."}]}
\ No newline at end of file
diff --git a/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json b/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json
index 0d0a47e..2522058 100644
--- a/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json
+++ b/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json
@@ -22,8 +22,8 @@
"rank_pattern": {},
"revision": null,
"target_modules": [
- "q_proj",
- "k_proj"
+ "k_proj",
+ "q_proj"
],
"task_type": "CAUSAL_LM",
"use_rslora": false
diff --git a/turbo_alignment/cli/common.py b/turbo_alignment/cli/common.py
index 9150e86..d823b42 100644
--- a/turbo_alignment/cli/common.py
+++ b/turbo_alignment/cli/common.py
@@ -4,7 +4,7 @@
from turbo_alignment import pipelines
from turbo_alignment.cli.app import app
-from turbo_alignment.common.tf.convert_to_base_model import peft_to_base_model
+from turbo_alignment.common.tf.merge_adapters_to_base import peft_to_base_model
from turbo_alignment.settings.datasets.multimodal import (
MultimodalDatasetProcessingSettings,
)
diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py
index 60afa6e..fc4d9b4 100644
--- a/turbo_alignment/common/data/multimodal/common.py
+++ b/turbo_alignment/common/data/multimodal/common.py
@@ -25,5 +25,4 @@ def read(self, path: str) -> torch.Tensor:
safetensors_file = self._get_safetensors_file(Path(path).parent)
if self.processed_tensors is None:
self.processed_tensors = safe_open(safetensors_file, framework='pt', device='cpu')
-
return self.processed_tensors.get_tensor(Path(path).name)
diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py
index 8a572bd..73736b7 100755
--- a/turbo_alignment/common/tf/loaders/model/model.py
+++ b/turbo_alignment/common/tf/loaders/model/model.py
@@ -29,7 +29,7 @@ def _load_pretrained_adapters(
) -> PeftModel:
return PeftModel.from_pretrained(
model,
- model_settings.adapter_path, # type: ignore
+ model_settings.adapter_path,
is_trainable=model_settings.is_trainable,
)
@@ -60,7 +60,16 @@ def load_model(
for new_token, old_token in model_settings.embeddings_initialization_strategy.items():
new_token_id = tokenizer.get_added_vocab()[new_token]
old_token_id = tokenizer.encode(old_token, add_special_tokens=False)[0]
- model.model.embed_tokens.weight[new_token_id, :] = model.model.embed_tokens.weight[old_token_id, :]
+
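+            # Initialize the newly added token's embedding from an existing token's;
+            # embedding attribute paths differ between GPT-NeoX and LLaMA model classes.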
+ if model.config.model_type == 'gpt_neox':
+ model.gpt_neox.embed_in.weight[new_token_id, :] = torch.clone(
+ model.gpt_neox.embed_in.weight[old_token_id, :]
+ )
+ if model_settings.model_type == 'causal':
+ model.embed_out.weight[new_token_id, :] = torch.clone(model.embed_out.weight[old_token_id, :])
+
+ elif model.config.model_type == 'llama':
+                model.model.embed_tokens.weight[new_token_id, :] = torch.clone(
+                    model.model.embed_tokens.weight[old_token_id, :]
+                )
if isinstance(model_settings, PreTrainedAdaptersModelSettings):
model = _load_pretrained_adapters(model, model_settings)
diff --git a/turbo_alignment/common/tf/convert_to_base_model.py b/turbo_alignment/common/tf/merge_adapters_to_base.py
similarity index 100%
rename from turbo_alignment/common/tf/convert_to_base_model.py
rename to turbo_alignment/common/tf/merge_adapters_to_base.py
diff --git a/turbo_alignment/common/tf/stopping_criteria.py b/turbo_alignment/common/tf/stopping_criteria.py
deleted file mode 100755
index df8314c..0000000
--- a/turbo_alignment/common/tf/stopping_criteria.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import torch
-from transformers import PreTrainedTokenizerBase, StoppingCriteria
-
-
-class EndTagCriteria(StoppingCriteria):
- def __init__(self, end_tag: str, tokenizer: PreTrainedTokenizerBase) -> None:
- super().__init__()
- self.tokenizer = tokenizer
- self.stop_tag = end_tag
-
- def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
- assert input_ids.shape[0] == 1
- return self.tokenizer.decode(input_ids[0], skip_special_tokens=False)[-len(self.stop_tag) :] == self.stop_tag
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index fb13633..54fcde2 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -5,7 +5,7 @@
import numpy as np
import torch
from allenai_common import Params
-from safetensors import SafetensorError # type: ignore[attr-defined]
+from safetensors._safetensors_rust import SafetensorError
from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.data.multimodal import BaseModalityReader
@@ -152,9 +152,7 @@ def __init__(self, tokenizer, source, settings) -> None:
super().__init__(tokenizer=tokenizer, source=source, settings=settings)
- self._chat_dataset = TrainChatDataset(
- tokenizer=tokenizer, source=source, settings=settings, read=False
- ) # type: ignore[misc]
+ self._chat_dataset = TrainChatDataset(tokenizer=tokenizer, source=source, settings=settings, read=False)
self._read()
@@ -192,6 +190,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
)
tokenized_record['labels'][modality_tokens_mask] = DISABLE_LOSS_LABEL
+
outputs.append(
{
**tokenized_record,
@@ -212,8 +211,11 @@ def __init__(
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
-
- self._chat_dataset = InferenceChatDataset(*args, random_cut=random_cut, **kwargs, read=False) # type: ignore
+ settings = kwargs['settings']
+ settings.random_cut = random_cut
+ self._chat_dataset = InferenceChatDataset(
+ tokenizer=kwargs['tokenizer'], source=kwargs['source'], settings=settings, read=False
+ )
self._read()
def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py
index 244d369..8a9d36e 100755
--- a/turbo_alignment/generators/base.py
+++ b/turbo_alignment/generators/base.py
@@ -4,14 +4,8 @@
import torch
from accelerate import Accelerator
-from transformers import (
- GenerationConfig,
- PreTrainedModel,
- PreTrainedTokenizerBase,
- StoppingCriteriaList,
-)
-
-from turbo_alignment.common.tf.stopping_criteria import EndTagCriteria
+from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
+
from turbo_alignment.dataset.base import BaseDataset
from turbo_alignment.dataset.base.models import DatasetRecord
from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings
@@ -105,27 +99,8 @@ def __init__(
self._return_logits = return_logits
- eos_token_id: list[int] = self._tokenizer.encode(
- custom_generation_settings.generation_eos_token, add_special_tokens=False
- )
-
- self._stopping_criteria: StoppingCriteriaList | None = None
- if len(eos_token_id) != 1:
- eos_token_id = []
- if transformers_settings.num_beams > 1 or transformers_settings.do_sample:
- raise ValueError('You should use only 1 eos token with num_beams > 1 or do_sample=True')
- self._stopping_criteria = StoppingCriteriaList(
- [
- EndTagCriteria(
- custom_generation_settings.generation_eos_token,
- tokenizer=tokenizer,
- )
- ]
- )
-
self._transformers_generator_parameters = GenerationConfig(
bos_token_id=self._tokenizer.bos_token_id,
- eos_token_id=eos_token_id,
**transformers_settings.dict(),
)
@@ -170,8 +145,8 @@ def _generate_from_batch(
@staticmethod
def _postprocess(input_indices: torch.Tensor, output_indices: torch.Tensor, remove_prompt: bool) -> torch.Tensor:
if remove_prompt:
- return output_indices[:, input_indices.shape[1] :]
- return output_indices
+ return output_indices[:, input_indices.shape[1] :].cpu()
+ return output_indices.cpu()
def _decode(self, token_indices: torch.Tensor) -> list[str]:
return self._tokenizer.batch_decode(
diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py
index 79798b1..acc4337 100755
--- a/turbo_alignment/generators/chat.py
+++ b/turbo_alignment/generators/chat.py
@@ -39,8 +39,8 @@ def _generate_from_batch_records(
inputs=batched_input_ids,
attention_mask=batched_attention_mask,
generation_config=self._transformers_generator_parameters,
+ tokenizer=self._tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
- stopping_criteria=self._stopping_criteria,
)
postprocessed_output_indices = self._postprocess(
@@ -84,8 +84,8 @@ def _generate_from_single_record(
inputs=input_ids,
attention_mask=attention_mask,
generation_config=self._transformers_generator_parameters,
+ tokenizer=self._tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
- stopping_criteria=self._stopping_criteria,
)
postprocessed_output_indices = self._postprocess(
@@ -102,18 +102,23 @@ def _generate_from_single_record(
if self._return_logits:
with torch.no_grad():
- logits = self._model(output_indices).logits
+ logits = self._model(output_indices).logits.cpu()
answer_tokens_ids = postprocessed_output_indices
input_token_ids = input_ids
- return ChatInferenceOutput(
- id=original_record.id,
- dataset_name=dataset_name,
- messages=original_record.messages,
- label=original_record.label,
- meta=original_record.meta,
- answers=[
+ answer_messages = [
+ AnswerMessage(
+ id=str(i),
+ content=a,
+ input_token_ids=input_token_ids,
+ answer_token_ids=a_t_ids.unsqueeze(0),
+                    logits=answer_logits.unsqueeze(0),
+ )
+                for i, (a, a_t_ids, answer_logits) in enumerate(zip(answers, answer_tokens_ids, logits))  # type: ignore[arg-type]
+ ]
+ else:
+ answer_messages = [
AnswerMessage(
id=str(i),
content=a,
@@ -122,5 +127,13 @@ def _generate_from_single_record(
logits=logits,
)
for i, a in enumerate(answers)
- ],
+ ]
+
+ return ChatInferenceOutput(
+ id=original_record.id,
+ dataset_name=dataset_name,
+ messages=original_record.messages,
+ label=original_record.label,
+ meta=original_record.meta,
+ answers=answer_messages,
)
diff --git a/turbo_alignment/generators/multimodal.py b/turbo_alignment/generators/multimodal.py
index ece60e5..0c64137 100755
--- a/turbo_alignment/generators/multimodal.py
+++ b/turbo_alignment/generators/multimodal.py
@@ -50,6 +50,7 @@ def _generate_from_single_record(
output_indices = self._model.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
+ tokenizer=self._tokenizer,
generation_config=self._transformers_generator_parameters,
)
diff --git a/turbo_alignment/generators/rag.py b/turbo_alignment/generators/rag.py
index 3658b36..4203de8 100755
--- a/turbo_alignment/generators/rag.py
+++ b/turbo_alignment/generators/rag.py
@@ -22,12 +22,12 @@ def _generate_from_single_record(
answer_indices, document_indices, doc_scores = self._model.generate(
inputs=input_ids,
generation_config=self._transformers_generator_parameters,
+ tokenizer=self._tokenizer.current_tokenizer,
pad_token_id=self._tokenizer.pad_token_id,
- stopping_criteria=self._stopping_criteria,
)
- answers = self._decode(token_indices=answer_indices)
- documents = self._decode(token_indices=document_indices)
+ answers = self._decode(token_indices=answer_indices.cpu())
+ documents = self._decode(token_indices=document_indices.cpu())
doc_scores = list(doc_scores[0])
return RagInferenceOutput(
diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py
index ab3e07e..ce9a9c8 100755
--- a/turbo_alignment/generators/rm.py
+++ b/turbo_alignment/generators/rm.py
@@ -28,7 +28,7 @@ def _generate_from_batch(
attn_mask = batch['attention_mask'].to(self.device)
with torch.no_grad():
- rewards = self._model(input_ids=input_ids, attention_mask=attn_mask).logits
+ rewards = self._model(input_ids=input_ids, attention_mask=attn_mask).logits.cpu()
rewards_w, rewards_l = rewards[: len(records)], rewards[len(records) :]
return [
@@ -74,7 +74,7 @@ def _generate_from_batch(
for i in range(0, len(input_ids), self._micro_batch):
input_ids_batch = input_ids[i : i + self._micro_batch].to(self.device)
attn_mask_batch = attn_mask[i : i + self._micro_batch].to(self.device)
- rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits)
+ rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits.cpu())
rewards = torch.cat(rewards, dim=0)
diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py
index 03235c0..513c5e7 100755
--- a/turbo_alignment/generators/vllm_chat.py
+++ b/turbo_alignment/generators/vllm_chat.py
@@ -25,11 +25,15 @@ def __init__(
model.set_tokenizer(tokenizer)
super().__init__(model, tokenizer, batch=batch)
- eos_token_id: list[int] = self._tokenizer.encode(
- custom_generation_settings.generation_eos_token, add_special_tokens=False
- )
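+        # vLLM stops on token ids (stop_token_ids), so only a single stop string can be encoded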
+ if isinstance(transformers_settings.stop_strings, list):
+            raise ValueError('You should use only one stop string with vLLM')
+
+ eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False)
- beam_search_params = {'best_of': transformers_settings.num_return_sequences, 'use_beam_search': False}
+ beam_search_params: dict[str, Any] = {
+ 'best_of': transformers_settings.num_return_sequences,
+ 'use_beam_search': False,
+ }
if transformers_settings.num_beams > 1:
beam_search_params['use_beam_search'] = True
beam_search_params['best_of'] = transformers_settings.num_beams
@@ -43,7 +47,7 @@ def __init__(
skip_special_tokens=custom_generation_settings.skip_special_tokens,
stop_token_ids=eos_token_id,
max_tokens=transformers_settings.max_new_tokens,
- **beam_search_params, # type: ignore[arg-type]
+ **beam_search_params,
)
def _generate_from_batch(
diff --git a/turbo_alignment/metrics/__init__.py b/turbo_alignment/metrics/__init__.py
index 8d1c27d..c70bcea 100755
--- a/turbo_alignment/metrics/__init__.py
+++ b/turbo_alignment/metrics/__init__.py
@@ -5,7 +5,6 @@
from turbo_alignment.metrics.meteor import MeteorMetric
from turbo_alignment.metrics.metric import Metric
from turbo_alignment.metrics.perplexity import PerplexityMetric
-from turbo_alignment.metrics.ragas import RagasMetrics
from turbo_alignment.metrics.registry import *
from turbo_alignment.metrics.retrieval_utility import RetrievalUtilityMetric
from turbo_alignment.metrics.reward import RewardMetric
diff --git a/turbo_alignment/metrics/ragas.py b/turbo_alignment/metrics/ragas.py
deleted file mode 100644
index f73183b..0000000
--- a/turbo_alignment/metrics/ragas.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# pylint: skip-file
-# pylint: disable-all
-# mypy: ignore-errors
-
-from datasets import Dataset
-from ragas import RunConfig, evaluate
-from ragas.metrics import (
- answer_relevancy,
- answer_similarity,
- context_entity_recall,
- context_precision,
- context_recall,
- faithfulness,
-)
-
-from turbo_alignment.dataset.chat import InferenceChatDataset
-from turbo_alignment.metrics.metric import Metric
-from turbo_alignment.metrics.registry import RagasMetricsSettings
-from turbo_alignment.settings.generators.outputs.chat import RagInferenceOutput
-from turbo_alignment.settings.metric import MetricResults, MetricType
-
-
-@Metric.register(MetricType.RAGAS_METRICS)
-class RagasMetrics(Metric):
- def __init__(self, settings: RagasMetricsSettings) -> None:
- self._settings: RagasMetricsSettings = settings
-
- if self._settings.openai_api_key is not None:
- # use openai endpoints if api key is provided
- from langchain_openai import OpenAI, OpenAIEmbeddings
-
- self._llm = OpenAI(openai_api_key=self._settings.openai_api_key, model='gpt-3.5-turbo-instruct')
- self._embeddings = OpenAIEmbeddings(
- openai_api_key=self._settings.openai_api_key, model='text-embedding-3-large'
- )
-
- elif self._settings.mistralai_api_key is not None:
- from langchain_mistralai import MistralAIEmbeddings
- from langchain_mistralai.chat_models import ChatMistralAI
-
- self._llm = ChatMistralAI(name='mistral-large', api_key=self._settings.mistralai_api_key)
-
- self._embeddings = MistralAIEmbeddings(api_key=self._settings.mistralai_api_key)
-
- def compute(
- self, dataset: InferenceChatDataset, generations: list[RagInferenceOutput], **kwargs
- ) -> list[MetricResults]:
- questions = [d['messages'][0].content for d in dataset]
- ground_truths = [d['messages'][1].content for d in dataset]
- retieved_docs = [g.documents for g in generations]
- answers = [g.answers[0].content for g in generations]
-
- ragas_dataset = Dataset.from_dict(
- {'question': questions, 'ground_truth': ground_truths, 'contexts': retieved_docs, 'answer': answers}
- )
-
- extra_kwargs = {}
- if self._llm:
- extra_kwargs['llm'] = self._llm
- if self._embeddings:
- extra_kwargs['embeddings'] = self._embeddings
-
- results = evaluate(
- ragas_dataset,
- metrics=[
- faithfulness,
- answer_relevancy,
- answer_similarity,
- context_precision,
- context_recall,
- context_entity_recall,
- ],
- **extra_kwargs,
- raise_exceptions=False,
- is_async=True,
- run_config=RunConfig(max_workers=1, max_wait=180, thread_timeout=600),
- )
-
- return results
diff --git a/turbo_alignment/metrics/registry.py b/turbo_alignment/metrics/registry.py
index b3ebf2f..3c65eb4 100755
--- a/turbo_alignment/metrics/registry.py
+++ b/turbo_alignment/metrics/registry.py
@@ -1,7 +1,6 @@
from enum import Enum
from allenai_common import Registrable
-from pydantic import model_validator
from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel
from turbo_alignment.settings.datasets.chat import ChatPromptTemplate
@@ -87,19 +86,3 @@ class ToolMetricsSettings(MetricSettings):
@MetricSettingsRegistry.register(MetricType.RETRIEVAL_UTILITY)
class RetrievalUtilitySettings(MetricSettings):
doc_sep_symbol: str = ''
-
-
-@MetricSettingsRegistry.register(MetricType.RAGAS_METRICS)
-class RagasMetricsSettings(MetricSettings):
- openai_api_key: str | None = None
- mistralai_api_key: str | None = None
-
- @model_validator(mode='before')
- def check_only_one_field(cls, values):
- openai_api_key = values.get('openai_api_key')
- mistralai_api_key = values.get('mistralai_api_key')
-
- if not bool(openai_api_key) and not bool(mistralai_api_key):
- raise ValueError('At least one of openai_api_key or mistralai_api_key must be specified')
-
- return values
diff --git a/turbo_alignment/metrics/reward.py b/turbo_alignment/metrics/reward.py
index 851e8ee..dee37a1 100755
--- a/turbo_alignment/metrics/reward.py
+++ b/turbo_alignment/metrics/reward.py
@@ -47,7 +47,7 @@ def compute(self, **kwargs) -> list[MetricResults]:
messages = [record['messages'] for record in dataset.records for _ in range(answers_per_context)]
answers = [
- AnswerMessage(id=ans_idx, content=ans)
+ AnswerMessage(id=str(ans_idx), content=ans)
for ctx_answers in predictions
for ans_idx, ans in enumerate(ctx_answers)
]
diff --git a/turbo_alignment/modeling/common/transformer.py b/turbo_alignment/modeling/common/transformer.py
index 3582c64..eebc380 100755
--- a/turbo_alignment/modeling/common/transformer.py
+++ b/turbo_alignment/modeling/common/transformer.py
@@ -77,7 +77,16 @@ class MultiheadAttention(torch.nn.MultiheadAttention):
# pylint: disable=arguments-differ
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: # type: ignore[override]
- return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+ return super().forward(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=None,
+ need_weights=False,
+ attn_mask=attn_mask,
+ average_attn_weights=True,
+ is_causal=False,
+ )[0]
class BlockWithMasking(torch.nn.Module):
diff --git a/turbo_alignment/modeling/imagebind/heads/registry.py b/turbo_alignment/modeling/imagebind/heads/registry.py
index e1f1388..9dbd9d7 100755
--- a/turbo_alignment/modeling/imagebind/heads/registry.py
+++ b/turbo_alignment/modeling/imagebind/heads/registry.py
@@ -1,3 +1,5 @@
+from abc import ABC
+
import torch
from turbo_alignment.modeling.common.helpers import SelectElement, SelectEOSAndProject
@@ -7,7 +9,7 @@
)
-class Heads(torch.nn.ModuleDict): # pylint: disable=abstract-method
+class Heads(ABC, torch.nn.ModuleDict):
@staticmethod
def __get_vision_head(
vision_embed_dim: int,
diff --git a/turbo_alignment/modeling/imagebind/imagebind.py b/turbo_alignment/modeling/imagebind/imagebind.py
index 269dd02..e039f96 100755
--- a/turbo_alignment/modeling/imagebind/imagebind.py
+++ b/turbo_alignment/modeling/imagebind/imagebind.py
@@ -1,4 +1,3 @@
-# pylint: skip-file
# pylint: disable=unused-import
import torch
diff --git a/turbo_alignment/modeling/imagebind/postprocessors/registry.py b/turbo_alignment/modeling/imagebind/postprocessors/registry.py
index d0a7bcd..b3d631d 100755
--- a/turbo_alignment/modeling/imagebind/postprocessors/registry.py
+++ b/turbo_alignment/modeling/imagebind/postprocessors/registry.py
@@ -1,3 +1,5 @@
+from abc import ABC
+
import torch
from turbo_alignment.modeling.common.helpers import LearnableLogitScaling, Normalize
@@ -7,7 +9,7 @@
)
-class Postprocessors(torch.nn.ModuleDict): # pylint: disable=abstract-method
+class Postprocessors(ABC, torch.nn.ModuleDict):
def __init__(self, _settings: ImageBindArchitectureSettings):
super().__init__()
diff --git a/turbo_alignment/modeling/imagebind/preprocessors/impl.py b/turbo_alignment/modeling/imagebind/preprocessors/impl.py
index 25d3a4a..ced8476 100755
--- a/turbo_alignment/modeling/imagebind/preprocessors/impl.py
+++ b/turbo_alignment/modeling/imagebind/preprocessors/impl.py
@@ -177,12 +177,11 @@ def get_pos_embedding(self, vision_input, all_vision_tokens):
class RGBDTPreprocessor(VerboseNNModule):
- # pylint: disable=dangerous-default-value
def __init__(
self,
rgbt_stem: PatchEmbedGeneric | None,
depth_stem: PatchEmbedGeneric | None,
- img_size: list = [3, 224, 224],
+ img_size: list | None = None,
num_cls_tokens: int = 1,
pos_embed_fn: Callable | None = None,
use_type_embed: bool = False,
@@ -190,6 +189,8 @@ def __init__(
) -> None:
super().__init__()
stem = rgbt_stem if rgbt_stem is not None else depth_stem
+ if img_size is None:
+ img_size = [3, 224, 224]
assert stem is not None
(
self.patches_layout,
@@ -277,11 +278,10 @@ def forward(self, vision: torch.Tensor | None = None, depth: torch.Tensor | None
class AudioPreprocessor(RGBDTPreprocessor):
- # pylint: disable=arguments-differ
def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
- def forward(self, audio: torch.Tensor | None = None) -> dict: # type: ignore [override]
+ def forward(self, *_args, audio: torch.Tensor | None = None, **_kwargs) -> dict:
# vision here is actually audio
return super().forward(vision=audio)
diff --git a/turbo_alignment/modeling/imagebind/preprocessors/registry.py b/turbo_alignment/modeling/imagebind/preprocessors/registry.py
index ae5e65d..d5e6043 100755
--- a/turbo_alignment/modeling/imagebind/preprocessors/registry.py
+++ b/turbo_alignment/modeling/imagebind/preprocessors/registry.py
@@ -1,3 +1,4 @@
+from abc import ABC
from functools import partial
import torch
@@ -16,7 +17,7 @@
)
-class Preprocessors(torch.nn.ModuleDict): # pylint: disable=abstract-method
+class Preprocessors(ABC, torch.nn.ModuleDict):
@staticmethod
def __get_rgbt_preprocessor(
video_frames: int,
diff --git a/turbo_alignment/modeling/imagebind/trunks/registry.py b/turbo_alignment/modeling/imagebind/trunks/registry.py
index 5a5bf6b..f3e113f 100755
--- a/turbo_alignment/modeling/imagebind/trunks/registry.py
+++ b/turbo_alignment/modeling/imagebind/trunks/registry.py
@@ -1,3 +1,4 @@
+from abc import ABC
from functools import partial
import torch
@@ -13,7 +14,7 @@
)
-class Trunks(torch.nn.ModuleDict): # pylint: disable=abstract-method
+class Trunks(ABC, torch.nn.ModuleDict):
@staticmethod
def __instantiate_trunk(
embed_dim: int,
diff --git a/turbo_alignment/modeling/multimodal/encoders/image/clip.py b/turbo_alignment/modeling/multimodal/encoders/image/clip.py
index 4d7b9bc..d803c17 100644
--- a/turbo_alignment/modeling/multimodal/encoders/image/clip.py
+++ b/turbo_alignment/modeling/multimodal/encoders/image/clip.py
@@ -25,8 +25,7 @@ def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, i
def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor:
if is_pickle:
return inputs
- # pylint: disable=line-too-long
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213 # noqa: E501
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
# -2 is default value of vision_feature_layer in llava config
# [1:] is everything after vit [cls] token
return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index 21de1e0..69c5055 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -1,4 +1,5 @@
from collections import defaultdict
+from pathlib import Path
import torch
from scipy.ndimage.measurements import find_objects, label
@@ -40,14 +41,13 @@ def __init__(self, *args, **kwargs) -> None:
Please, set n_modality_embs to {self.encoders[modality].n_modality_embs} in config.'
if self.modality_projector_initialization_mapping:
- if self.modality_projector_initialization_mapping.get(modality):
+ state_dict_path: Path | None = self.modality_projector_initialization_mapping.get(modality, None)
+ if state_dict_path is not None:
logger.info(f'Loading {modality} connector weights')
- state_dictionary = torch.load(
- self.modality_projector_initialization_mapping[modality] # type: ignore[arg-type]
- )
+ state_dictionary = torch.load(state_dict_path)
modality_adapters[modality].load_state_dict(state_dictionary)
- logger.info(f'Sucsessfully loaded from {self.modality_projector_initialization_mapping[modality]}')
+                logger.info(f'Successfully loaded from {state_dict_path}')
self.modality_adapters = torch.nn.ModuleDict(modality_adapters)
@@ -122,7 +122,6 @@ def forward(
labels: torch.LongTensor | None = None,
) -> ModelOutput:
multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask)
-
return self.language_model(
inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask
)
diff --git a/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py b/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py
index ed4fbf0..7bf08df 100644
--- a/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py
+++ b/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py
@@ -1,3 +1,4 @@
+from abc import ABC
from functools import partial
import torch
@@ -14,6 +15,7 @@
class HoneybeeVisualProjectorConfig(BaseModel):
projector_type: str = 'c-abs'
+ initializer_range: float = 1.0
depth: int = 3
mlp_depth: int = 2
hidden_size: int = 1024
@@ -42,7 +44,7 @@ def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size:
num_eos_tokens = config.num_eos_tokens
if num_eos_tokens:
eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size))
- torch.nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range) # type: ignore
+ torch.nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range)
else:
eos_tokens = None
@@ -58,9 +60,9 @@ def build_prenorm(config: HoneybeeVisualProjectorConfig):
def build_mlp(depth: int, hidden_size: int, output_hidden_size: int):
- layers = [torch.nn.Linear(hidden_size, output_hidden_size)]
+ layers: list[torch.nn.Module] = [torch.nn.Linear(hidden_size, output_hidden_size)]
for _ in range(1, depth):
- layers.append(torch.nn.SiLU()) # type: ignore
+ layers.append(torch.nn.SiLU())
layers.append(torch.nn.Linear(output_hidden_size, output_hidden_size))
return torch.nn.Sequential(*layers)
@@ -131,8 +133,7 @@ def _load_from_state_dict(self, state_dict, *args, **kwargs):
super()._load_from_state_dict(state_dict, *args, **kwargs)
-# pylint: disable=abstract-method
-class ConvProjector(Projector):
+class ConvProjector(ABC, Projector):
def _forward(self, x):
# x: [B, L, dim]
hw = int(x.size(1) ** 0.5)
diff --git a/turbo_alignment/modeling/rag/rag_model.py b/turbo_alignment/modeling/rag/rag_model.py
index 9a58c2a..9bd926b 100755
--- a/turbo_alignment/modeling/rag/rag_model.py
+++ b/turbo_alignment/modeling/rag/rag_model.py
@@ -296,7 +296,7 @@ def generate(
self,
inputs: torch.LongTensor,
generation_config: GenerationConfig | None = None,
- **kwargs, # pylint: disable=unused-argument
+ **kwargs,
) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]:
# TODO remove code duplicate with forward
@@ -343,7 +343,7 @@ def generate(
input_ids=joined_input_ids,
generation_config=generation_config,
pad_token_id=self.tokenizer.pad_token_id,
- stopping_criteria=kwargs['stopping_criteria'],
+ tokenizer=kwargs.get('tokenizer', None),
)
# TODO chose max-prob sequence with accounting for doc probs
only_answer_output = output_sequences[:, joined_input_ids.shape[-1] :]
diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index c9c7ba8..f85a65d 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -47,7 +47,7 @@ def _read_modality_objects(self, reader: BaseModalityReader, dataset_path: Path)
total_number_of_objects = len(list(dataset_path.iterdir()))
logger.info('📖 Reading modality objects...')
- files_paths = [] # type: ignore
+ files_paths: list[Path] = []
for extension in available_extensions:
files_paths.extend(dataset_path.glob(f'*.{extension}'))
@@ -121,6 +121,7 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
experiment_settings.output_file_path.mkdir(parents=True, exist_ok=True)
tensors = self._get_safetensor_dict(encoded_modality_tensors, encoded_file_paths)
+
del encoded_modality_tensors
save_file(
diff --git a/turbo_alignment/pipelines/train/ddpo.py b/turbo_alignment/pipelines/train/ddpo.py
index 8a35a38..a280b6f 100755
--- a/turbo_alignment/pipelines/train/ddpo.py
+++ b/turbo_alignment/pipelines/train/ddpo.py
@@ -94,7 +94,8 @@ def _get_trainer(
data_collator: Callable,
rm_model: PreTrainedModel = None,
) -> DDPOTrainer:
- model.config.use_cache = False
+ model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
+
extra_args = {'rm': rm_model}
return DDPOTrainer(
diff --git a/turbo_alignment/pipelines/train/dpo.py b/turbo_alignment/pipelines/train/dpo.py
index 7290edd..155f6d3 100755
--- a/turbo_alignment/pipelines/train/dpo.py
+++ b/turbo_alignment/pipelines/train/dpo.py
@@ -73,6 +73,7 @@ def _get_trainer(
data_collator: Callable,
):
model.config.use_cache = not training_args.gradient_checkpointing
+
extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
ref_model = load_model(experiment_settings.model_settings, tokenizer)
diff --git a/turbo_alignment/pipelines/train/kto.py b/turbo_alignment/pipelines/train/kto.py
index 20c8c5f..b34d813 100755
--- a/turbo_alignment/pipelines/train/kto.py
+++ b/turbo_alignment/pipelines/train/kto.py
@@ -72,6 +72,8 @@ def _get_trainer(
val_dataset: Dataset,
data_collator: Callable,
):
+ model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
+
extra_args = {}
if experiment_settings.trainer_settings.use_ref_model:
ref_model = load_model(experiment_settings.model_settings, tokenizer)
diff --git a/turbo_alignment/pipelines/train/sft.py b/turbo_alignment/pipelines/train/sft.py
index 85b5346..a1bddec 100755
--- a/turbo_alignment/pipelines/train/sft.py
+++ b/turbo_alignment/pipelines/train/sft.py
@@ -72,6 +72,8 @@ def _get_trainer(
data_collator: DataCollatorMixin,
**_kwargs,
) -> MultiGPUCherryPicksTrainer:
+ model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing
+
return MultiGPUCherryPicksTrainer(
model=model,
args=training_args,
diff --git a/turbo_alignment/settings/generators/chat.py b/turbo_alignment/settings/generators/chat.py
index 3869baf..01bcc09 100755
--- a/turbo_alignment/settings/generators/chat.py
+++ b/turbo_alignment/settings/generators/chat.py
@@ -3,6 +3,5 @@
class CustomChatGenerationSettings(ExtraFieldsNotAllowedBaseModel):
skip_special_tokens: bool = True
- generation_eos_token: str = ''
remove_prompt: bool = True
batch: int = 1
diff --git a/turbo_alignment/settings/metric.py b/turbo_alignment/settings/metric.py
index ee71e44..a0cecac 100755
--- a/turbo_alignment/settings/metric.py
+++ b/turbo_alignment/settings/metric.py
@@ -19,8 +19,6 @@ class MetricType(str, Enum):
KL: str = 'kl'
TOOL_CALL_METRICS: str = 'tool_call_metrics'
RETRIEVAL_UTILITY: str = 'retrieval_utility'
- INTENT_CLASSIFIER_ACCURACY: str = 'intent_classifier_accuracy'
- RAGAS_METRICS: str = 'ragas_metrics'
class ElementWiseScores(ExtraFieldsNotAllowedBaseModel):
diff --git a/turbo_alignment/settings/pipelines/train/base.py b/turbo_alignment/settings/pipelines/train/base.py
index e9f8227..7ed9d66 100755
--- a/turbo_alignment/settings/pipelines/train/base.py
+++ b/turbo_alignment/settings/pipelines/train/base.py
@@ -25,12 +25,10 @@ class BaseTrainExperimentSettings(BaseSettings):
log_path: Path = Path('train_output')
seed: int = 42
- # early_stopping: EarlyStoppingSettings | None = None
-
trainer_settings: TrainerSettings
tokenizer_settings: TokenizerSettings
- model_settings: ModelForPeftSettings | PreTrainedModelSettings | PreTrainedAdaptersModelSettings
+ model_settings: (ModelForPeftSettings | PreTrainedModelSettings | PreTrainedAdaptersModelSettings)
train_dataset_settings: MultiDatasetSettings
val_dataset_settings: MultiDatasetSettings
diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py
index c7a0ae0..64d0a43 100755
--- a/turbo_alignment/settings/pipelines/train/dpo.py
+++ b/turbo_alignment/settings/pipelines/train/dpo.py
@@ -19,6 +19,8 @@ class DPOLossesType(str, Enum):
KTO = 'kto'
SLIC_HF = 'slic_hf'
CPO = 'cpo'
+ ORPO = 'orpo'
+ SIMPO = 'simpo'
class DPOLossSettings(ExtraFieldsNotAllowedBaseModel):
@@ -57,6 +59,17 @@ class SlicHfLossSettings(DPOLossSettings):
norm: bool = False
+class SimPOLossSettings(DPOLossSettings):
+ loss_type: Literal[DPOLossesType.SIMPO]
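+    # beta scales the implicit reward; gamma sets the target reward margin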
+ beta: float = 0.1
+ gamma: float = 0.1
+
+
+class ORPOLossSettings(DPOLossSettings):
+ loss_type: Literal[DPOLossesType.ORPO]
+ beta: float = 0.1
+
+
class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
sync_ref_model: bool = False
alpha: float = 1.0
@@ -64,19 +77,22 @@ class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel):
class DPOTrainerSettings(TrainerSettings):
- loss_settings: SigmoidLossSettings | HingeLossSettings | IPOLossSettings | KTOLossSettings | CPOLossSettings
+ loss_settings: (
+ SigmoidLossSettings
+ | HingeLossSettings
+ | IPOLossSettings
+ | KTOLossSettings
+ | CPOLossSettings
+ | ORPOLossSettings
+ | SimPOLossSettings
+ | SlicHfLossSettings
+ )
sync_ref_settings: SyncRefModelSettings
use_ref_model: bool = True
use_sft_model: bool = False
average_log_prob: bool = Field(default=False, description='Normalize log probability by length or not')
-class SlicHfTrainerSettings(TrainerSettings):
- loss_settings: SlicHfLossSettings
- use_ref_model: bool = True
- average_log_prob: bool = Field(default=False, description='Normalize log probability by length or not')
-
-
class DPOTrainExperimentSettings(BaseTrainExperimentSettings):
train_dataset_settings: PairPreferenceMultiDatasetSettings
val_dataset_settings: PairPreferenceMultiDatasetSettings
diff --git a/turbo_alignment/settings/tf/generation.py b/turbo_alignment/settings/tf/generation.py
index ee977ff..14537d9 100755
--- a/turbo_alignment/settings/tf/generation.py
+++ b/turbo_alignment/settings/tf/generation.py
@@ -10,3 +10,4 @@ class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel):
top_p: float = 1.0
top_k: int = 50
temperature: float = 1.0
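+    # Forwarded to transformers.GenerationConfig; generation stops once any of these
+    # strings is produced (the tokenizer must also be passed to model.generate())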
+ stop_strings: str | list[str] = ''
diff --git a/turbo_alignment/settings/tf/peft.py b/turbo_alignment/settings/tf/peft.py
index b826760..f479a27 100755
--- a/turbo_alignment/settings/tf/peft.py
+++ b/turbo_alignment/settings/tf/peft.py
@@ -11,7 +11,7 @@ class BasePeftSettings(ExtraFieldsNotAllowedBaseModel):
class LoraSettings(BasePeftSettings):
- name: Literal[PeftType.LORA] = PeftType.LORA # type: ignore[valid-type]
+ name: Literal[PeftType.LORA] = PeftType.LORA
r: int = 16
lora_alpha: int = 16
lora_dropout: float = 0.05
@@ -21,19 +21,19 @@ class LoraSettings(BasePeftSettings):
class PrefixTuningSettings(BasePeftSettings):
- name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING # type: ignore[valid-type]
+ name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING
encoder_hidden_size: int
prefix_projection: bool
class PromptTuningSettings(BasePeftSettings):
- name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING # type: ignore[valid-type]
+ name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING
num_virtual_tokens: int = 32
prompt_tuning_init_text: str | None = None
class PTuningSettings(BasePeftSettings):
- name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING # type: ignore[valid-type]
+ name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING
num_virtual_tokens: int = 32
encoder_reparameterization_type: str = 'MLP'
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index 8d91070..ff17c91 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -1,4 +1,3 @@
-# mypy: disable-error-code="call-overload"
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Literal
@@ -25,10 +24,15 @@
from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelCallback
from turbo_alignment.constants import DISABLE_LOSS_LABEL
from turbo_alignment.settings.pipelines.train.dpo import (
+ CPOLossSettings,
DPOLossesType,
HingeLossSettings,
IPOLossSettings,
+ KTOLossSettings,
+ ORPOLossSettings,
SigmoidLossSettings,
+ SimPOLossSettings,
+ SlicHfLossSettings,
SyncRefModelSettings,
)
from turbo_alignment.trainers.utils import (
@@ -198,6 +202,7 @@ def compute_loss(
return loss, chosen_rewards, rejected_rewards
+@DPOLossRegistry.register(DPOLossesType.SLIC_HF)
class SlicHfLoss(DPOLossRegistry):
def __init__(self, delta: float = 1, beta: float = 1.0, lam: float = 1.0, norm: bool = False) -> None:
self.delta = delta
@@ -232,12 +237,83 @@ def compute_loss(
return loss, chosen_rewards, rejected_rewards
+@DPOLossRegistry.register(DPOLossesType.SIMPO)
+class SimPOLoss(DPOLossRegistry):
+ def __init__(self, *args, beta: float = 0.1, gamma: float = 0.1, **kwargs) -> None:
+ self.beta = beta
+ self.gamma = gamma
+ super().__init__(*args, **kwargs)
+
+ def compute_loss(
+ self,
+ policy_chosen_logps: torch.FloatTensor,
+ policy_rejected_logps: torch.FloatTensor,
+ reference_chosen_logps: torch.FloatTensor | None,
+ reference_rejected_logps: torch.FloatTensor | None,
+ policy_best_decode_logps: torch.FloatTensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
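+        # SimPO is reference-free: the implicit reward is the length-normalized policy
+        # log-probability scaled by beta, and gamma acts as a target reward margin.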
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
+
+ logits = pi_logratios - self.gamma
+
+        chosen_rewards = self.beta * policy_chosen_logps.detach()
+        rejected_rewards = self.beta * policy_rejected_logps.detach()
+
+ loss = -F.logsigmoid(self.beta * logits)
+
+ return (
+ loss,
+ chosen_rewards,
+ rejected_rewards,
+ )
+
+
+@DPOLossRegistry.register(DPOLossesType.ORPO)
+class ORPOLoss(DPOLossRegistry):
+ def __init__(self, *args, beta: float = 0.1, **kwargs) -> None:
+ self.beta = beta
+ super().__init__(*args, **kwargs)
+
+ def compute_loss(
+ self,
+ policy_chosen_logps: torch.FloatTensor,
+ policy_rejected_logps: torch.FloatTensor,
+ reference_chosen_logps: torch.FloatTensor | None,
+ reference_rejected_logps: torch.FloatTensor | None,
+ policy_best_decode_logps: torch.FloatTensor | None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
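+        # ORPO (arXiv:2403.07691) odds-ratio term: log[odds(p_w) / odds(p_l)],
+        # where odds(p) = p / (1 - p); the clamp keeps 1 - p strictly positive.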
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+ torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
+ - torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7))
+ )
+
+ ratio = -F.logsigmoid(log_odds)
+ losses = self.beta * ratio
+
+ chosen_rewards = self.beta * policy_chosen_logps.detach()
+ rejected_rewards = self.beta * policy_rejected_logps.detach()
+
+ return losses, chosen_rewards, rejected_rewards
+
+
@dataclass
class DPOTrainingArguments(TrainingArguments):
- loss_settings: SigmoidLossSettings | HingeLossSettings | IPOLossSettings | SlicHfLoss | KTOLoss | CPOLoss = field(
+ loss_settings: (
+ SigmoidLossSettings
+ | HingeLossSettings
+ | IPOLossSettings
+ | SlicHfLossSettings
+ | KTOLossSettings
+ | CPOLossSettings
+ | ORPOLossSettings
+ | SimPOLossSettings
+ ) = field(
default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID)
+ ) # type: ignore[call-overload]
+    sync_ref_settings: SyncRefModelSettings = field(
+        default_factory=SyncRefModelSettings
)
- sync_ref_settings: SyncRefModelSettings = field(default_factory=SyncRefModelSettings())
use_ref_model: bool = True
use_sft_model: bool = False
average_log_prob: bool = False
@@ -269,6 +345,10 @@ def __init__(
if hasattr(args, 'loss_settings'):
self.loss_type = args.loss_settings['loss_type'] # type: ignore[index]
+
+ if self.loss_type in (DPOLossesType.SIMPO, DPOLossesType.ORPO) and not args.average_log_prob:
+            raise ValueError(f'average_log_prob must be True (length-normalized log-probs) when using {self.loss_type}')
+
loss_args = args.loss_settings
loss_args.pop('loss_type') # type: ignore[union-attr]
self.dpo_loss_registry = DPOLossRegistry.by_name(self.loss_type)(**loss_args)
@@ -286,6 +366,9 @@ def __init__(
**kwargs,
)
+ if hasattr(args, 'loss_settings') and self.loss_type in (DPOLossesType.SIMPO, DPOLossesType.ORPO):
+            logger.info(f'You can disable the reference model (use_ref_model=False) when using {self.loss_type} to save memory')
+
self.ref_model = ref_model
self.sft_model = sft_model
@@ -407,7 +490,10 @@ def get_batch_metrics(
policy_best_decode_logps,
) = self.concatenated_forward(model, batch)
- reference_chosen_logps, reference_rejected_logps = self._get_logps(self.ref_model, batch)
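+        # Placeholder refs: reference logps are only computed below when the
+        # selected loss actually uses a reference model.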
+ reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
+
+ if self.args.use_ref_model or self.loss_type not in (DPOLossesType.SIMPO, DPOLossesType.ORPO):
+ reference_chosen_logps, reference_rejected_logps = self._get_logps(self.ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps=policy_chosen_logps,
@@ -423,14 +509,16 @@ def get_batch_metrics(
metrics = self._compute_metrics(metrics, dpo_prefix_name, chosen_rewards, rejected_rewards)
- metrics[f'{prefix}logps/ref_rejected'] = (reference_rejected_logps).detach().cpu().mean().item()
- metrics[f'{prefix}logps/ref_chosen'] = (reference_chosen_logps).detach().cpu().mean().item()
metrics[f'{prefix}logps/rejected'] = (policy_rejected_logps).detach().cpu().mean().item()
metrics[f'{prefix}logps/chosen'] = (policy_chosen_logps).detach().cpu().mean().item()
metrics[f'{prefix}logits/rejected'] = (policy_rejected_logits).detach().cpu().mean().item()
metrics[f'{prefix}logits/chosen'] = (policy_chosen_logits).detach().cpu().mean().item()
+ if self.args.use_ref_model:
+ metrics[f'{prefix}logps/ref_rejected'] = (reference_rejected_logps).detach().cpu().mean().item()
+ metrics[f'{prefix}logps/ref_chosen'] = (reference_chosen_logps).detach().cpu().mean().item()
+
if self.loss_type == DPOLossesType.KTO:
kto_chosen_KL = (
(policy_chosen_logps.detach().cpu() - reference_chosen_logps.detach().cpu()).mean().clamp(min=0)
@@ -448,6 +536,22 @@ def get_batch_metrics(
metrics[f'{prefix}rewards/kto_grad_term_chosen'] = kto_grad_term_chosen.item()
metrics[f'{prefix}rewards/kto_grad_term_rejected'] = kto_grad_term_rejected.item()
+ elif self.loss_type == DPOLossesType.ORPO:
+ labels_w = batch['inputs_w']['labels'][:, 1:].clone()
+ loss_mask_w = labels_w != DISABLE_LOSS_LABEL
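+            # ORPO's NLL term normalizes the chosen logps by the number of
+            # non-masked target tokens (per-token NLL).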
+ length_norm_policy_chosen_logps = policy_chosen_logps / loss_mask_w.sum(-1)
+
+ log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+ torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
+ - torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7))
+ )
+ ratio = -F.logsigmoid(log_odds)
+ nll_loss = -length_norm_policy_chosen_logps
+
+ metrics[f'{prefix}orpo/nll_loss'] = nll_loss.clone().detach().cpu().mean().item()
+ metrics[f'{prefix}orpo/ratio'] = (ratio).detach().cpu().mean().item()
+ metrics[f'{prefix}orpo/log_odds'] = (log_odds).detach().cpu().mean().item()
+
if self.sft_model is not None:
sft_chosen_logps, sft_rejected_logps = self._get_logps(self.sft_model, batch)
@@ -463,7 +567,7 @@ def get_batch_metrics(
sft_prefix_name = prefix + 'rewards/sft_'
metrics = self._compute_metrics(metrics, sft_prefix_name, sft_chosen_rewards, sft_rejected_rewards)
- return losses.mean(), metrics # type: ignore
+ return losses.mean(), metrics
def _compute_metrics(
self, metrics: dict[str, float], prefix_name: str, chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor
@@ -523,7 +627,7 @@ def prediction_step(
'logits_test/rejected': metrics['logits_test/rejected'],
}
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
- logits = torch.stack(logits).mean(axis=1) # type: ignore[arg-type, call-overload]
+ logits = torch.stack(logits).mean(axis=1) # type: ignore[call-overload, arg-type]
labels = torch.zeros(logits.shape[0])
return loss.detach(), logits, labels
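
Taken together, the new losses are selected through `loss_settings`, which the trainer consumes as a dict keyed by `loss_type` (see `__init__` above). A rough sketch of enabling SimPO; the field values are illustrative, not taken from the shipped configs:

```python
from turbo_alignment.settings.pipelines.train.dpo import (
    DPOLossesType,
    SimPOLossSettings,
)
from turbo_alignment.trainers.dpo import DPOTrainingArguments

# average_log_prob=True is mandatory for SimPO/ORPO (the trainer raises otherwise);
# use_ref_model=False skips the reference forward pass to save memory.
args = DPOTrainingArguments(
    output_dir='checkpoints',
    loss_settings=SimPOLossSettings(loss_type=DPOLossesType.SIMPO, beta=2.0, gamma=0.5).model_dump(),
    average_log_prob=True,
    use_ref_model=False,
)
```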
diff --git a/turbo_alignment/trainers/kto.py b/turbo_alignment/trainers/kto.py
index 29adceb..e3e7c6e 100755
--- a/turbo_alignment/trainers/kto.py
+++ b/turbo_alignment/trainers/kto.py
@@ -31,8 +31,8 @@
class KTOTrainingArguments(TrainingArguments):
beta: float = 0.1
sync_ref_settings: SyncRefModelSettings = field(
- default_factory=SyncRefModelSettings() # type: ignore[call-overload]
- )
+        default_factory=SyncRefModelSettings
+    )
use_ref_model: bool = True
average_log_prob: bool = False
undesirable_weight: float = 1.0
diff --git a/turbo_alignment/trainers/multigpu.py b/turbo_alignment/trainers/multigpu.py
index 3a75e39..7c84076 100755
--- a/turbo_alignment/trainers/multigpu.py
+++ b/turbo_alignment/trainers/multigpu.py
@@ -1,7 +1,5 @@
-import os
from typing import Callable
-import numpy as np
from torch import nn
from torch.utils.data import Dataset
from transformers import (
@@ -16,7 +14,6 @@
TrainingArguments,
)
from transformers.integrations import get_reporting_integration_callbacks
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from turbo_alignment.common.tf.callbacks.common import WandbMetricsCallbackHandler
@@ -61,56 +58,3 @@ def __init__(
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
self.control: TrainerControl = self.callback_handler.on_init_end(self.args, self.state, self.control)
-
- def _save_checkpoint(self, model, trial, metrics=None): # pylint: disable=unused-argument
- # FIX: https://github.com/huggingface/transformers/issues/28119
- # https://github.com/huggingface/transformers/pull/29370
-
- # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
- # want to save except FullyShardedDDP.
- # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
-
- # Save model checkpoint
- checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
-
- if self.hp_search_backend is None and trial is None:
- self.store_flos()
-
- run_dir = self._get_output_dir(trial=trial)
- output_dir = os.path.join(run_dir, checkpoint_folder)
- self.save_model(output_dir, _internal_call=True)
-
- if not self.args.save_only_model:
- # Save optimizer and scheduler
- self._save_optimizer_and_scheduler(output_dir)
- # Save RNG state
- self._save_rng_state(output_dir)
-
- # Determine the new best metric / best model checkpoint
- if metrics is not None and self.args.metric_for_best_model is not None:
- metric_to_check = self.args.metric_for_best_model
- if not metric_to_check.startswith('eval_'):
- metric_to_check = f'eval_{metric_to_check}'
- metric_value = metrics[metric_to_check]
-
- operator = np.greater if self.args.greater_is_better else np.less
- if (
- self.state.best_metric is None
- or self.state.best_model_checkpoint is None
- or operator(metric_value, self.state.best_metric)
- ):
- self.state.best_metric = metric_value
- self.state.best_model_checkpoint = output_dir
-
- # Save the Trainer state
- if self.args.should_save:
- self.state.save_to_json(os.path.join(output_dir, 'trainer_state.json'))
-
- if self.args.push_to_hub:
- self._push_from_checkpoint(output_dir)
-
- # Maybe delete some older checkpoints.
- if self.args.should_save:
- # Solely rely on numerical checkpoint id for rotation.
- # mtime is not reliable especially on some fuse fs in cloud environments.
- self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
diff --git a/turbo_alignment/trainers/multimodal.py b/turbo_alignment/trainers/multimodal.py
index ea87f21..59189dc 100755
--- a/turbo_alignment/trainers/multimodal.py
+++ b/turbo_alignment/trainers/multimodal.py
@@ -26,7 +26,7 @@
class TrainerCustomSave(MultiGPUCherryPicksTrainer):
- def _save_checkpoint(self, model, trial, metrics=None): # pylint: disable=unused-argument
+ def _save_checkpoint(self, model, trial, metrics=None):
logger.info('Running custom _save_checkpoint')
checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
run_dir = self._get_output_dir(trial=trial)
@@ -53,9 +53,15 @@ def _save_checkpoint(self, model, trial, metrics=None): # pylint: disable=unuse
(output_dir / 'adapter').mkdir(parents=True, exist_ok=True)
(output_dir / 'tokenizer').mkdir(parents=True, exist_ok=True)
- torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
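+        # nn.DataParallel wraps the model; unwrap via .module to reach the
+        # modality adapters and the language model before saving.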
+        unwrapped_model = model.module if isinstance(model, torch.nn.DataParallel) else model
+        torch.save(
+            unwrapped_model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt'
+        )
+        unwrapped_model.language_model.save_pretrained(output_dir / 'adapter')
- model.language_model.save_pretrained(output_dir / 'adapter')
self.tokenizer.save_pretrained(output_dir / 'tokenizer')
@@ -269,7 +275,7 @@ def prediction_step(
'logits_test/rejected': metrics['logits_test/rejected'],
}
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
- logits = torch.stack(logits).mean(axis=1) # type: ignore[arg-type, call-overload]
+ logits = torch.stack(logits).mean(axis=1) # type: ignore[call-overload, arg-type]
labels = torch.zeros(logits.shape[0])
return loss.detach(), logits, labels
@@ -285,7 +291,7 @@ def log(self, logs: Dict[str, float]) -> None:
del self._stored_metrics[train_eval]
return super().log(logs) # pylint: disable=no-member
- def _save_checkpoint(self, model, trial, metrics=None): # pylint: disable=unused-argument
+ def _save_checkpoint(self, model, trial, metrics=None):
logger.info('Running custom _save_checkpoint')
checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
run_dir = self._get_output_dir(trial=trial)
diff --git a/turbo_alignment/trainers/utils.py b/turbo_alignment/trainers/utils.py
index ebd054a..8021d0a 100755
--- a/turbo_alignment/trainers/utils.py
+++ b/turbo_alignment/trainers/utils.py
@@ -54,10 +54,10 @@ def prepare_model_for_deepspeed(
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, 'config'):
- hidden_size: int | None = ( # type: ignore
- max(model.config.hidden_sizes) # type: ignore
- if getattr(model.config, 'hidden_sizes', None) # type: ignore
- else getattr(model.config, 'hidden_size', None) # type: ignore
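+            # hidden_size drives the ZeRO stage-3 config adjustments made below.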
+ hidden_size: int | None = (
+ max(model.config.hidden_sizes)
+ if getattr(model.config, 'hidden_sizes', None)
+ else getattr(model.config, 'hidden_size', None)
)
if hidden_size is not None and config_kwargs['zero_optimization']['stage'] == 3:
diff --git a/tutorials/multimodal/create_tutorial_dataset.py b/tutorials/multimodal/create_tutorial_dataset.py
new file mode 100644
index 0000000..ee7582d
--- /dev/null
+++ b/tutorials/multimodal/create_tutorial_dataset.py
@@ -0,0 +1,43 @@
+from pathlib import Path
+
+from datasets import load_dataset
+
+from turbo_alignment.common.data.io import write_jsonl
+from turbo_alignment.dataset.chat.models import ChatMessageRole
+from turbo_alignment.dataset.multimodal.models import (
+ MultimodalDatasetRecord,
+ MultimodalImageMessage,
+ MultimodalTextMessage,
+)
+
+
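+# Map one photochat_plus row to a multimodal chat record: a user image message
+# followed by the bot's first image description as text.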
+def convert_to_multimodal_record(row):
+ return MultimodalDatasetRecord(
+ id=row['id'],
+ messages=[
+ MultimodalImageMessage(role=ChatMessageRole.USER, content=f"images/00000/{str(row['id']).zfill(9)}.jpg"),
+ MultimodalTextMessage(role=ChatMessageRole.BOT, content=row['image_descriptions'][0].strip()),
+ ],
+    ).model_dump()
+
+
+if __name__ == '__main__':
+ dataset = load_dataset('passing2961/photochat_plus')['train']
+ dataset = dataset.add_column('id', range(len(dataset)))
+ dataset = dataset.train_test_split(test_size=0.1)
+
+ dataset['train_multimodal_records'] = dataset['train'].map(
+ convert_to_multimodal_record, remove_columns=dataset['train'].column_names
+ )
+ dataset['val_multimodal_records'] = dataset['test'].map(
+ convert_to_multimodal_record, remove_columns=dataset['test'].column_names
+ )
+
+    write_jsonl(list(dataset['train_multimodal_records']), Path('train_multimodal.jsonl'))
+    write_jsonl(list(dataset['val_multimodal_records']), Path('val_multimodal.jsonl'))
diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json
new file mode 100644
index 0000000..6910ab3
--- /dev/null
+++ b/tutorials/multimodal/multimodal.json
@@ -0,0 +1,196 @@
+{
+ "train_dataset_settings": {
+ "sources": [
+ {
+ "name": "train",
+ "records_path": "train_chat.jsonl",
+ "num_samples": 10
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+        "bot": "<bot>",
+        "user": "<user>",
+        "system": "<system>"
+ },
+      "prefix_template": "<RS>{role}",
+      "suffix_template": "</RS>"
+ },
+ "modality_token_mapping": {
+      "image": "<img>",
+ "audio": "