diff --git a/Makefile b/Makefile index 8fa7f59..94724e2 100644 --- a/Makefile +++ b/Makefile @@ -39,6 +39,9 @@ ifneq ($(CODE),) unify --in-place --recursive $(CODE) tutorials endif +bash-dev: + docker run -v $(shell pwd):/app --rm -it ${DEV_IMAGE} bash + lock: rm poetry.lock poetry lock diff --git a/README.md b/README.md index 0ce8668..5ed343f 100755 --- a/README.md +++ b/README.md @@ -56,7 +56,6 @@ Turbo-Alignment supports a wide range of methods for model training and alignmen - **📏** Length - **🌀** Perplexity - **🌟** METEOR -- **📐** RAGAS - **🔍** Retrieval Utility @@ -86,7 +85,7 @@ Examples of datasets are available [here](docs/dataset_example.md). - [RSO](#-RSO-sampling) - [Common](#-common) - [Preprocess](#-preprocess-common) - - [Convert to base](#-convert-to-base-common) + - [Merge adapters to base](#-merge-adapters-to-base-common) # Train @@ -170,7 +169,9 @@ To launch RAG: ## 🚀 Installation ### 📦 Python Package -⌛️ in progress.. +```bash +pip install turbo-alignment +``` ### 🛠️ From Source For the latest features before an official release: diff --git a/docs/dataset_example.md b/docs/dataset_example.md index 7ca2f81..6ea91c5 100644 --- a/docs/dataset_example.md +++ b/docs/dataset_example.md @@ -11,7 +11,7 @@ - [Pair Preferences Dataset](#-pair-preferences-dataset) - [KTO Dataset](#-kto-dataset) - [Sampling Dataset](#-sampling-dataset) -- [Multimodal Dataset ](#-multimodal-dataset) (⌛️ Work in progress...) +- [Multimodal Dataset ](#-multimodal-dataset) - [Classification Dataset](#-classification-dataset) - [DDPO Dataset](#-ddpo-dataset) (⌛️ Work in progress...) @@ -118,9 +118,45 @@ Example: ## Multimodal Dataset -⌛️ in progress.. +- `messages`: `list[MultimodalChatMessage]` — This is a sequence of messages that make up the chat history. Each `MultimodalChatMessage` includes: + - `role` - The participant's role in the conversation (e.g., `user` or `bot`). + - `type` - The type of modality (e.g., `text` or `image`). + - `content` - If the `type` is `text`, it's the textual content of the message. If it's `image`, it's the file path. +Example: +```json +{ + "id": "0", + "messages": [ + { + "role": "system", + "type": "text", + "content": "You are a Multimodal AI assistant." + }, + { + "role": "user", + "type": "image", + "content": "/path/to/cat.jpg" + }, + { + "role": "user", + "type": "image", + "content": "/path/to/dog.jpg" + }, + { + "role": "user", + "type": "text", + "content": "What's the difference between these two images?" + }, + { + "role": "bot", + "type": "text", + "content": "The two images in question both feature animals, albeit of different species. The first image depicts a cat, while the second image depicts a dog. Both are generally perceived as animals that evoke positive emotional responses."
+ } + ] +} +``` diff --git a/pyproject.toml b/pyproject.toml index 18991d7..c8a6871 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "turbo-alignment" packages = [ { include = "turbo_alignment" }, ] -version = "0.1.0" +version = "0.0.2" description = "turbo-alignment repository" authors = ["T Mega Alignment Team " ] @@ -44,7 +44,6 @@ pydantic ="^2.7.0" timm ="^0.9.7" opencv-python = "^4.10.0.84" langchain-huggingface = "^0.0.3" -ragas = "^0.1.10" [tool.poetry.group.deepspeed.dependencies] accelerate = "0.27" diff --git a/tests/cli/test_dpo_train.py b/tests/cli/test_dpo_train.py index 08e1f76..23bd4af 100755 --- a/tests/cli/test_dpo_train.py +++ b/tests/cli/test_dpo_train.py @@ -12,7 +12,10 @@ @pytest.mark.parametrize( 'config_path', - [FIXTURES_PATH / 'configs/train/dpo/base.json'], + [ + FIXTURES_PATH / 'configs/train/dpo/base.json', + FIXTURES_PATH / 'configs/train/dpo/simpo.json', + ], ) def test_dpo_train(config_path: Path): result = runner.invoke( diff --git a/tests/cli/test_multimodal_inference.py b/tests/cli/test_multimodal_inference.py index 91d1b0e..60c9a05 100644 --- a/tests/cli/test_multimodal_inference.py +++ b/tests/cli/test_multimodal_inference.py @@ -1,26 +1,26 @@ -# from pathlib import Path +from pathlib import Path -# import pytest -# from typer.testing import CliRunner +import pytest +from typer.testing import CliRunner -# from tests.constants import FIXTURES_PATH -# from turbo_alignment.cli import app -# from turbo_alignment.settings.pipelines.inference.multimodal import ( -# MultimodalInferenceExperimentSettings, -# ) +from tests.constants import FIXTURES_PATH +from turbo_alignment.cli import app +from turbo_alignment.settings.pipelines.inference.multimodal import ( + MultimodalInferenceExperimentSettings, +) -# runner = CliRunner() +runner = CliRunner() -# @pytest.mark.parametrize( -# 'config_path', -# [ -# FIXTURES_PATH / 'configs/inference/multimodal/llama_llava_clip_pickle.json', -# ], -# ) -# def test_multimodal_inference_mlp_with_preprocessing(config_path: Path): -# result = runner.invoke( -# app, ['inference_multimodal', '--inference_settings_path', str(config_path)], catch_exceptions=False -# ) -# assert result.exit_code == 0 -# assert MultimodalInferenceExperimentSettings.parse_file(config_path).save_path.is_dir() +@pytest.mark.parametrize( + 'config_path', + [ + FIXTURES_PATH / 'configs/inference/multimodal/llama_llava_clip_pickle.json', + ], +) +def test_multimodal_inference_mlp_with_preprocessing(config_path: Path): + result = runner.invoke( + app, ['inference_multimodal', '--inference_settings_path', str(config_path)], catch_exceptions=False + ) + assert result.exit_code == 0 + assert MultimodalInferenceExperimentSettings.parse_file(config_path).save_path.is_dir() diff --git a/tests/cli/test_sft.py b/tests/cli/test_sft.py index f66941f..ab01abd 100755 --- a/tests/cli/test_sft.py +++ b/tests/cli/test_sft.py @@ -14,24 +14,10 @@ 'config_path', [ FIXTURES_PATH / 'configs/train/sft/base.json', + FIXTURES_PATH / 'configs/train/sft/sft_with_rm_metric.json', ], ) def test_sft_train(config_path: Path): result = runner.invoke(app, ['train_sft', '--experiment_settings_path', str(config_path)], catch_exceptions=False) assert result.exit_code == 0 assert SftTrainExperimentSettings.parse_file(config_path).log_path.is_dir() - - -@pytest.mark.parametrize( - 'config_path', - [ - FIXTURES_PATH / 'configs/train/sft/resume_from_checkpoint.json', - ], -) -def test_sft_from_checkpoint(config_path: Path): - result = runner.invoke( - app, - ['train_sft', 
'--experiment_settings_path', str(config_path)], - ) - assert result.exit_code == 0 - assert SftTrainExperimentSettings.parse_file(config_path).log_path.is_dir() diff --git a/tests/fixtures/configs/inference/rag/base.json b/tests/fixtures/configs/inference/rag/base.json index 2282b7b..dff8d56 100755 --- a/tests/fixtures/configs/inference/rag/base.json +++ b/tests/fixtures/configs/inference/rag/base.json @@ -41,10 +41,10 @@ "num_beams": 1, "max_new_tokens": 10, "repetition_penalty": 1.2, + "stop_strings": "", "do_sample": false }, "custom_settings": { - "generation_eos_token": "", "skip_special_tokens": false, "remove_prompt": false } diff --git a/tests/fixtures/configs/train/dpo/base.json b/tests/fixtures/configs/train/dpo/base.json index 8df1392..a06caad 100755 --- a/tests/fixtures/configs/train/dpo/base.json +++ b/tests/fixtures/configs/train/dpo/base.json @@ -56,10 +56,10 @@ "generator_transformers_settings": { "num_beams": 1, "do_sample": false, + "stop_strings": "", "max_new_tokens": 8 }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": false }, "dataset_settings": { diff --git a/tests/fixtures/configs/train/dpo/simpo.json b/tests/fixtures/configs/train/dpo/simpo.json new file mode 100755 index 0000000..593463d --- /dev/null +++ b/tests/fixtures/configs/train/dpo/simpo.json @@ -0,0 +1,134 @@ +{ + "train_dataset_settings": { + "sources": [ + { + "name": "rm_preferences_test", + "records_path": "tests/fixtures/datasets/rm/train_preferences.jsonl", + "sample_rate": 1 + } + ], + "chat_settings":{ + "prompt_template": { + "role_tag_mapping": { + "bot": "", + "user": "", + "system": "" + }, + "prefix_template": "{role}", + "suffix_template": "" + }, + "max_tokens_count": 120 + }, + "add_labels": true, + "dataset_type": "pair_preferences" + }, + "val_dataset_settings": { + "sources": [ + { + "name": "rm_preferences_test", + "records_path": "tests/fixtures/datasets/rm/val_preferences.jsonl", + "sample_rate": 1 + } + ], + "chat_settings":{ + "prompt_template": { + "role_tag_mapping": { + "bot": "", + "user": "", + "system": "" + }, + "prefix_template": "{role}", + "suffix_template": "" + }, + "max_tokens_count": 120 + }, + "add_labels": true, + "dataset_type": "pair_preferences" + }, + "model_settings": { + "model_path": "tests/fixtures/models/llama2_tiny", + "model_type": "causal", + "transformers_settings": {}, + "adapter_path": "tests/fixtures/models/llama2_tiny_fine_tuned_with_adapters/trainer", + "is_trainable": true + }, + "cherry_pick_settings": { + "generator_transformers_settings": { + "num_beams": 1, + "do_sample": false, + "stop_strings": "", + "max_new_tokens": 8 + }, + "custom_generation_settings": { + "skip_special_tokens": false + }, + "dataset_settings": { + "sources": [ + { + "name": "chat_test", + "records_path": "tests/fixtures/datasets/chat/train_chat.jsonl", + "num_samples": 2 + } + ], + "prompt_template": { + "role_tag_mapping": { + "bot": "", + "user": "", + "system": "" + }, + "prefix_template": "{role}", + "suffix_template": "" + }, + "dataset_type": "chat", + "max_tokens_count": 150, + "only_answer_loss": true + }, + "metric_settings": [ + { + "type": "length", + "parameters": {"need_average": [true]} + }, + { + "type": "kl", + "parameters": { + "need_average": [true], + "ref_logits_type": "sft" + } + } + ] + }, + "tokenizer_settings": {}, + "trainer_settings": { + "evaluation_strategy": "steps", + "per_device_train_batch_size": 2, + "per_device_eval_batch_size": 2, + "gradient_accumulation_steps": 2, + "eval_steps": 4, + 
"save_steps": 4, + "logging_steps": 1, + "learning_rate": 0.0003, + "num_train_epochs": 2, + "lr_scheduler_type": "cosine", + "warmup_steps": 2, + "fp16": false, + "bf16": false, + "optim": "adamw_torch", + "save_total_limit": 1, + "average_log_prob": true, + "loss_settings": { + "loss_type": "simpo" + }, + "sync_ref_settings": { + "sync_ref_model": false + }, + "use_ref_model": false, + "use_sft_model": true, + "no_cuda": true + }, + "wandb_settings": { + "project_name": "alignment", + "run_name": "dpo", + "entity": "turbo-alignment" + }, + "log_path": "test_dpo_llama_train_output" +} diff --git a/tests/fixtures/configs/train/kto/base.json b/tests/fixtures/configs/train/kto/base.json index 30d4192..28dc7a4 100755 --- a/tests/fixtures/configs/train/kto/base.json +++ b/tests/fixtures/configs/train/kto/base.json @@ -54,10 +54,10 @@ "generator_transformers_settings": { "num_beams": 1, "do_sample": false, + "stop_strings": "", "max_new_tokens": 8 }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": false }, "dataset_settings": { diff --git a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json index f140504..0195034 100644 --- a/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_c_abs_clip_pickle.json @@ -31,7 +31,7 @@ "start_modality_token": "", "end_modality_token": "", "dataset_type": "multimodal", - "max_tokens_count": 2000, + "max_tokens_count": 300, "only_answer_loss": true, "truncate_top": false }, @@ -67,7 +67,7 @@ "start_modality_token": "", "end_modality_token": "", "dataset_type": "multimodal", - "max_tokens_count": 2000, + "max_tokens_count": 300, "only_answer_loss": true, "truncate_top": false }, @@ -144,12 +144,12 @@ "cherry_pick_settings": { "generator_transformers_settings": { "num_beams": 1, - "max_new_tokens": 128, + "max_new_tokens": 16, "repetition_penalty": 1.1, + "stop_strings": "", "do_sample": true }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": true }, "dataset_settings": { @@ -170,7 +170,7 @@ "suffix_template": "" }, "dataset_type": "multimodal", - "max_tokens_count": 2000, + "max_tokens_count": 300, "n_modality_embeddings": 225, "start_modality_token": "", "end_modality_token": "", diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json index 1b51ab2..98fbad8 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_base_clip.json @@ -146,10 +146,10 @@ "num_beams": 1, "max_new_tokens": 128, "repetition_penalty": 1.1, + "stop_strings": "", "do_sample": true }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": true }, "dataset_settings": { diff --git a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json index 9f23e43..d7a46e7 100644 --- a/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json +++ b/tests/fixtures/configs/train/multimodal/llama_llava_clip_pickle.json @@ -146,10 +146,10 @@ "num_beams": 1, "max_new_tokens": 4, "repetition_penalty": 1.1, + "stop_strings": "", "do_sample": false }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": true }, "dataset_settings": { diff --git 
a/tests/fixtures/configs/train/rag/base.json b/tests/fixtures/configs/train/rag/base.json index b003986..f11b4b8 100755 --- a/tests/fixtures/configs/train/rag/base.json +++ b/tests/fixtures/configs/train/rag/base.json @@ -87,10 +87,10 @@ "num_beams": 3, "max_new_tokens": 16, "repetition_penalty": 1.1, + "stop_strings": "", "do_sample": true }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": false }, "dataset_settings": { diff --git a/tests/fixtures/configs/train/sft/base.json b/tests/fixtures/configs/train/sft/base.json index 8e0b4d4..b5dc9fd 100755 --- a/tests/fixtures/configs/train/sft/base.json +++ b/tests/fixtures/configs/train/sft/base.json @@ -69,10 +69,10 @@ "cherry_pick_settings": { "generator_transformers_settings": { "num_beams": 3, + "stop_strings": ["", ""], "max_new_tokens": 8 }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": false }, "dataset_settings": { diff --git a/tests/fixtures/configs/train/sft/prompt_tuning.json b/tests/fixtures/configs/train/sft/prompt_tuning.json index 0a09dfc..c187176 100755 --- a/tests/fixtures/configs/train/sft/prompt_tuning.json +++ b/tests/fixtures/configs/train/sft/prompt_tuning.json @@ -63,10 +63,10 @@ "num_beams": 1, "max_new_tokens": 35, "repetition_penalty": 1.1, + "stop_strings": "", "do_sample": true }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": false }, "dataset_settings": { diff --git a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json index 2a4515c..c281971 100755 --- a/tests/fixtures/configs/train/sft/sft_with_rm_metric.json +++ b/tests/fixtures/configs/train/sft/sft_with_rm_metric.json @@ -70,10 +70,10 @@ "generator_transformers_settings": { "num_beams": 1, "num_return_sequences": 2, + "stop_strings": "", "max_new_tokens": 8 }, "custom_generation_settings": { - "generation_eos_token": "", "skip_special_tokens": false }, "dataset_settings": { diff --git a/tests/fixtures/datasets/multimodal/image_chat.jsonl b/tests/fixtures/datasets/multimodal/image_chat.jsonl index b621ea1..e53f0b5 100644 --- a/tests/fixtures/datasets/multimodal/image_chat.jsonl +++ b/tests/fixtures/datasets/multimodal/image_chat.jsonl @@ -1,5 +1,4 @@ {"id": "0", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_1.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "Sorry, I will not describe the scene."}]} {"id": "1", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_2.jpg"}, {"role": "user", "type": "text", "content": "What do you see on the image?"}, {"role": "bot", "type": "text", "content": "I see nothing."}, {"role": "user", "type": "text", "content": "What about this one?"}, {"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_1.jpg"}, {"role": "bot", "type": "text", "content": "Sorry..."}]} {"id": "2", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_3.jpg"}, {"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_4.jpg"}, {"role": "user", "type": "text", "content": "Please, describe these two photos."}, {"role": "bot", "type": "text", "content": "OK."}]} -{"id": "3", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_5.jpg"}, 
{"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "No."}]} -{"id": "4", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/does_not_exist.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "No."}]} \ No newline at end of file +{"id": "3", "messages": [{"role": "user", "type": "image", "content": "tests/fixtures/datasets/multimodal/images/img_5.jpg"}, {"role": "user", "type": "text", "content": "Describe the scene"}, {"role": "bot", "type": "text", "content": "No."}]} \ No newline at end of file diff --git a/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json b/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json index 0d0a47e..2522058 100644 --- a/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json +++ b/tests/fixtures/models/llama2_tiny_multimodal_clip_mlp/adapter/adapter_config.json @@ -22,8 +22,8 @@ "rank_pattern": {}, "revision": null, "target_modules": [ - "q_proj", - "k_proj" + "k_proj", + "q_proj" ], "task_type": "CAUSAL_LM", "use_rslora": false diff --git a/turbo_alignment/cli/common.py b/turbo_alignment/cli/common.py index 9150e86..d823b42 100644 --- a/turbo_alignment/cli/common.py +++ b/turbo_alignment/cli/common.py @@ -4,7 +4,7 @@ from turbo_alignment import pipelines from turbo_alignment.cli.app import app -from turbo_alignment.common.tf.convert_to_base_model import peft_to_base_model +from turbo_alignment.common.tf.merge_adapters_to_base import peft_to_base_model from turbo_alignment.settings.datasets.multimodal import ( MultimodalDatasetProcessingSettings, ) diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py index 60afa6e..fc4d9b4 100644 --- a/turbo_alignment/common/data/multimodal/common.py +++ b/turbo_alignment/common/data/multimodal/common.py @@ -25,5 +25,4 @@ def read(self, path: str) -> torch.Tensor: safetensors_file = self._get_safetensors_file(Path(path).parent) if self.processed_tensors is None: self.processed_tensors = safe_open(safetensors_file, framework='pt', device='cpu') - return self.processed_tensors.get_tensor(Path(path).name) diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index 8a572bd..73736b7 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -29,7 +29,7 @@ def _load_pretrained_adapters( ) -> PeftModel: return PeftModel.from_pretrained( model, - model_settings.adapter_path, # type: ignore + model_settings.adapter_path, is_trainable=model_settings.is_trainable, ) @@ -60,7 +60,16 @@ def load_model( for new_token, old_token in model_settings.embeddings_initialization_strategy.items(): new_token_id = tokenizer.get_added_vocab()[new_token] old_token_id = tokenizer.encode(old_token, add_special_tokens=False)[0] - model.model.embed_tokens.weight[new_token_id, :] = model.model.embed_tokens.weight[old_token_id, :] + + if model.config.model_type == 'gpt_neox': + model.gpt_neox.embed_in.weight[new_token_id, :] = torch.clone( + model.gpt_neox.embed_in.weight[old_token_id, :] + ) + if model_settings.model_type == 'causal': + model.embed_out.weight[new_token_id, :] = torch.clone(model.embed_out.weight[old_token_id, :]) + + elif model.config.model_type == 'llama': + model.model.embed_tokens.weight[new_token_id, :] = 
model.model.embed_tokens.weight[old_token_id, :] if isinstance(model_settings, PreTrainedAdaptersModelSettings): model = _load_pretrained_adapters(model, model_settings) diff --git a/turbo_alignment/common/tf/convert_to_base_model.py b/turbo_alignment/common/tf/merge_adapters_to_base.py similarity index 100% rename from turbo_alignment/common/tf/convert_to_base_model.py rename to turbo_alignment/common/tf/merge_adapters_to_base.py diff --git a/turbo_alignment/common/tf/stopping_criteria.py b/turbo_alignment/common/tf/stopping_criteria.py deleted file mode 100755 index df8314c..0000000 --- a/turbo_alignment/common/tf/stopping_criteria.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -from transformers import PreTrainedTokenizerBase, StoppingCriteria - - -class EndTagCriteria(StoppingCriteria): - def __init__(self, end_tag: str, tokenizer: PreTrainedTokenizerBase) -> None: - super().__init__() - self.tokenizer = tokenizer - self.stop_tag = end_tag - - def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool: - assert input_ids.shape[0] == 1 - return self.tokenizer.decode(input_ids[0], skip_special_tokens=False)[-len(self.stop_tag) :] == self.stop_tag diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py index fb13633..54fcde2 100644 --- a/turbo_alignment/dataset/multimodal/multimodal.py +++ b/turbo_alignment/dataset/multimodal/multimodal.py @@ -5,7 +5,7 @@ import numpy as np import torch from allenai_common import Params -from safetensors import SafetensorError # type: ignore[attr-defined] +from safetensors._safetensors_rust import SafetensorError from turbo_alignment.common.data.io import read_jsonl from turbo_alignment.common.data.multimodal import BaseModalityReader @@ -152,9 +152,7 @@ def __init__(self, tokenizer, source, settings) -> None: super().__init__(tokenizer=tokenizer, source=source, settings=settings) - self._chat_dataset = TrainChatDataset( - tokenizer=tokenizer, source=source, settings=settings, read=False - ) # type: ignore[misc] + self._chat_dataset = TrainChatDataset(tokenizer=tokenizer, source=source, settings=settings, read=False) self._read() @@ -192,6 +190,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s ) tokenized_record['labels'][modality_tokens_mask] = DISABLE_LOSS_LABEL + outputs.append( { **tokenized_record, @@ -212,8 +211,11 @@ def __init__( **kwargs, ) -> None: super().__init__(*args, **kwargs) - - self._chat_dataset = InferenceChatDataset(*args, random_cut=random_cut, **kwargs, read=False) # type: ignore + settings = kwargs['settings'] + settings.random_cut = random_cut + self._chat_dataset = InferenceChatDataset( + tokenizer=kwargs['tokenizer'], source=kwargs['source'], settings=settings, read=False + ) self._read() def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]: diff --git a/turbo_alignment/generators/base.py b/turbo_alignment/generators/base.py index 244d369..8a9d36e 100755 --- a/turbo_alignment/generators/base.py +++ b/turbo_alignment/generators/base.py @@ -4,14 +4,8 @@ import torch from accelerate import Accelerator -from transformers import ( - GenerationConfig, - PreTrainedModel, - PreTrainedTokenizerBase, - StoppingCriteriaList, -) - -from turbo_alignment.common.tf.stopping_criteria import EndTagCriteria +from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase + from turbo_alignment.dataset.base import BaseDataset from 
turbo_alignment.dataset.base.models import DatasetRecord from turbo_alignment.settings.generators.chat import CustomChatGenerationSettings @@ -105,27 +99,8 @@ def __init__( self._return_logits = return_logits - eos_token_id: list[int] = self._tokenizer.encode( - custom_generation_settings.generation_eos_token, add_special_tokens=False - ) - - self._stopping_criteria: StoppingCriteriaList | None = None - if len(eos_token_id) != 1: - eos_token_id = [] - if transformers_settings.num_beams > 1 or transformers_settings.do_sample: - raise ValueError('You should use only 1 eos token with num_beams > 1 or do_sample=True') - self._stopping_criteria = StoppingCriteriaList( - [ - EndTagCriteria( - custom_generation_settings.generation_eos_token, - tokenizer=tokenizer, - ) - ] - ) - self._transformers_generator_parameters = GenerationConfig( bos_token_id=self._tokenizer.bos_token_id, - eos_token_id=eos_token_id, **transformers_settings.dict(), ) @@ -170,8 +145,8 @@ def _generate_from_batch( @staticmethod def _postprocess(input_indices: torch.Tensor, output_indices: torch.Tensor, remove_prompt: bool) -> torch.Tensor: if remove_prompt: - return output_indices[:, input_indices.shape[1] :] - return output_indices + return output_indices[:, input_indices.shape[1] :].cpu() + return output_indices.cpu() def _decode(self, token_indices: torch.Tensor) -> list[str]: return self._tokenizer.batch_decode( diff --git a/turbo_alignment/generators/chat.py b/turbo_alignment/generators/chat.py index 79798b1..acc4337 100755 --- a/turbo_alignment/generators/chat.py +++ b/turbo_alignment/generators/chat.py @@ -39,8 +39,8 @@ def _generate_from_batch_records( inputs=batched_input_ids, attention_mask=batched_attention_mask, generation_config=self._transformers_generator_parameters, + tokenizer=self._tokenizer, pad_token_id=self._tokenizer.pad_token_id, - stopping_criteria=self._stopping_criteria, ) postprocessed_output_indices = self._postprocess( @@ -84,8 +84,8 @@ def _generate_from_single_record( inputs=input_ids, attention_mask=attention_mask, generation_config=self._transformers_generator_parameters, + tokenizer=self._tokenizer, pad_token_id=self._tokenizer.pad_token_id, - stopping_criteria=self._stopping_criteria, ) postprocessed_output_indices = self._postprocess( @@ -102,18 +102,23 @@ def _generate_from_single_record( if self._return_logits: with torch.no_grad(): - logits = self._model(output_indices).logits + logits = self._model(output_indices).logits.cpu() answer_tokens_ids = postprocessed_output_indices input_token_ids = input_ids - return ChatInferenceOutput( - id=original_record.id, - dataset_name=dataset_name, - messages=original_record.messages, - label=original_record.label, - meta=original_record.meta, - answers=[ + answer_messages = [ + AnswerMessage( + id=str(i), + content=a, + input_token_ids=input_token_ids, + answer_token_ids=a_t_ids.unsqueeze(0), + logits=l.unsqueeze(0), + ) + for i, (a, a_t_ids, l) in enumerate(zip(answers, answer_tokens_ids, logits)) # type: ignore[arg-type] + ] + else: + answer_messages = [ AnswerMessage( id=str(i), content=a, @@ -122,5 +127,13 @@ def _generate_from_single_record( logits=logits, ) for i, a in enumerate(answers) - ], + ] + + return ChatInferenceOutput( + id=original_record.id, + dataset_name=dataset_name, + messages=original_record.messages, + label=original_record.label, + meta=original_record.meta, + answers=answer_messages, ) diff --git a/turbo_alignment/generators/multimodal.py b/turbo_alignment/generators/multimodal.py index ece60e5..0c64137 100755 --- 
a/turbo_alignment/generators/multimodal.py +++ b/turbo_alignment/generators/multimodal.py @@ -50,6 +50,7 @@ def _generate_from_single_record( output_indices = self._model.language_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, + tokenizer=self._tokenizer, generation_config=self._transformers_generator_parameters, ) diff --git a/turbo_alignment/generators/rag.py b/turbo_alignment/generators/rag.py index 3658b36..4203de8 100755 --- a/turbo_alignment/generators/rag.py +++ b/turbo_alignment/generators/rag.py @@ -22,12 +22,12 @@ def _generate_from_single_record( answer_indices, document_indices, doc_scores = self._model.generate( inputs=input_ids, generation_config=self._transformers_generator_parameters, + tokenizer=self._tokenizer.current_tokenizer, pad_token_id=self._tokenizer.pad_token_id, - stopping_criteria=self._stopping_criteria, ) - answers = self._decode(token_indices=answer_indices) - documents = self._decode(token_indices=document_indices) + answers = self._decode(token_indices=answer_indices.cpu()) + documents = self._decode(token_indices=document_indices.cpu()) doc_scores = list(doc_scores[0]) return RagInferenceOutput( diff --git a/turbo_alignment/generators/rm.py b/turbo_alignment/generators/rm.py index ab3e07e..ce9a9c8 100755 --- a/turbo_alignment/generators/rm.py +++ b/turbo_alignment/generators/rm.py @@ -28,7 +28,7 @@ def _generate_from_batch( attn_mask = batch['attention_mask'].to(self.device) with torch.no_grad(): - rewards = self._model(input_ids=input_ids, attention_mask=attn_mask).logits + rewards = self._model(input_ids=input_ids, attention_mask=attn_mask).logits.cpu() rewards_w, rewards_l = rewards[: len(records)], rewards[len(records) :] return [ @@ -74,7 +74,7 @@ def _generate_from_batch( for i in range(0, len(input_ids), self._micro_batch): input_ids_batch = input_ids[i : i + self._micro_batch].to(self.device) attn_mask_batch = attn_mask[i : i + self._micro_batch].to(self.device) - rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits) + rewards.extend(self._model(input_ids=input_ids_batch, attention_mask=attn_mask_batch).logits.cpu()) rewards = torch.cat(rewards, dim=0) diff --git a/turbo_alignment/generators/vllm_chat.py b/turbo_alignment/generators/vllm_chat.py index 03235c0..513c5e7 100755 --- a/turbo_alignment/generators/vllm_chat.py +++ b/turbo_alignment/generators/vllm_chat.py @@ -25,11 +25,15 @@ def __init__( model.set_tokenizer(tokenizer) super().__init__(model, tokenizer, batch=batch) - eos_token_id: list[int] = self._tokenizer.encode( - custom_generation_settings.generation_eos_token, add_special_tokens=False - ) + if isinstance(transformers_settings.stop_strings, list): + raise ValueError('You should use only 1 eos token with VLLM') + + eos_token_id: list[int] = self._tokenizer.encode(transformers_settings.stop_strings, add_special_tokens=False) - beam_search_params = {'best_of': transformers_settings.num_return_sequences, 'use_beam_search': False} + beam_search_params: dict[str, Any] = { + 'best_of': transformers_settings.num_return_sequences, + 'use_beam_search': False, + } if transformers_settings.num_beams > 1: beam_search_params['use_beam_search'] = True beam_search_params['best_of'] = transformers_settings.num_beams @@ -43,7 +47,7 @@ def __init__( skip_special_tokens=custom_generation_settings.skip_special_tokens, stop_token_ids=eos_token_id, max_tokens=transformers_settings.max_new_tokens, - **beam_search_params, # type: ignore[arg-type] + **beam_search_params, ) def 
_generate_from_batch( diff --git a/turbo_alignment/metrics/__init__.py b/turbo_alignment/metrics/__init__.py index 8d1c27d..c70bcea 100755 --- a/turbo_alignment/metrics/__init__.py +++ b/turbo_alignment/metrics/__init__.py @@ -5,7 +5,6 @@ from turbo_alignment.metrics.meteor import MeteorMetric from turbo_alignment.metrics.metric import Metric from turbo_alignment.metrics.perplexity import PerplexityMetric -from turbo_alignment.metrics.ragas import RagasMetrics from turbo_alignment.metrics.registry import * from turbo_alignment.metrics.retrieval_utility import RetrievalUtilityMetric from turbo_alignment.metrics.reward import RewardMetric diff --git a/turbo_alignment/metrics/ragas.py b/turbo_alignment/metrics/ragas.py deleted file mode 100644 index f73183b..0000000 --- a/turbo_alignment/metrics/ragas.py +++ /dev/null @@ -1,79 +0,0 @@ -# pylint: skip-file -# pylint: disable-all -# mypy: ignore-errors - -from datasets import Dataset -from ragas import RunConfig, evaluate -from ragas.metrics import ( - answer_relevancy, - answer_similarity, - context_entity_recall, - context_precision, - context_recall, - faithfulness, -) - -from turbo_alignment.dataset.chat import InferenceChatDataset -from turbo_alignment.metrics.metric import Metric -from turbo_alignment.metrics.registry import RagasMetricsSettings -from turbo_alignment.settings.generators.outputs.chat import RagInferenceOutput -from turbo_alignment.settings.metric import MetricResults, MetricType - - -@Metric.register(MetricType.RAGAS_METRICS) -class RagasMetrics(Metric): - def __init__(self, settings: RagasMetricsSettings) -> None: - self._settings: RagasMetricsSettings = settings - - if self._settings.openai_api_key is not None: - # use openai endpoints if api key is provided - from langchain_openai import OpenAI, OpenAIEmbeddings - - self._llm = OpenAI(openai_api_key=self._settings.openai_api_key, model='gpt-3.5-turbo-instruct') - self._embeddings = OpenAIEmbeddings( - openai_api_key=self._settings.openai_api_key, model='text-embedding-3-large' - ) - - elif self._settings.mistralai_api_key is not None: - from langchain_mistralai import MistralAIEmbeddings - from langchain_mistralai.chat_models import ChatMistralAI - - self._llm = ChatMistralAI(name='mistral-large', api_key=self._settings.mistralai_api_key) - - self._embeddings = MistralAIEmbeddings(api_key=self._settings.mistralai_api_key) - - def compute( - self, dataset: InferenceChatDataset, generations: list[RagInferenceOutput], **kwargs - ) -> list[MetricResults]: - questions = [d['messages'][0].content for d in dataset] - ground_truths = [d['messages'][1].content for d in dataset] - retieved_docs = [g.documents for g in generations] - answers = [g.answers[0].content for g in generations] - - ragas_dataset = Dataset.from_dict( - {'question': questions, 'ground_truth': ground_truths, 'contexts': retieved_docs, 'answer': answers} - ) - - extra_kwargs = {} - if self._llm: - extra_kwargs['llm'] = self._llm - if self._embeddings: - extra_kwargs['embeddings'] = self._embeddings - - results = evaluate( - ragas_dataset, - metrics=[ - faithfulness, - answer_relevancy, - answer_similarity, - context_precision, - context_recall, - context_entity_recall, - ], - **extra_kwargs, - raise_exceptions=False, - is_async=True, - run_config=RunConfig(max_workers=1, max_wait=180, thread_timeout=600), - ) - - return results diff --git a/turbo_alignment/metrics/registry.py b/turbo_alignment/metrics/registry.py index b3ebf2f..3c65eb4 100755 --- a/turbo_alignment/metrics/registry.py +++ 
b/turbo_alignment/metrics/registry.py @@ -1,7 +1,6 @@ from enum import Enum from allenai_common import Registrable -from pydantic import model_validator from turbo_alignment.settings.base import ExtraFieldsNotAllowedBaseModel from turbo_alignment.settings.datasets.chat import ChatPromptTemplate @@ -87,19 +86,3 @@ class ToolMetricsSettings(MetricSettings): @MetricSettingsRegistry.register(MetricType.RETRIEVAL_UTILITY) class RetrievalUtilitySettings(MetricSettings): doc_sep_symbol: str = '' - - -@MetricSettingsRegistry.register(MetricType.RAGAS_METRICS) -class RagasMetricsSettings(MetricSettings): - openai_api_key: str | None = None - mistralai_api_key: str | None = None - - @model_validator(mode='before') - def check_only_one_field(cls, values): - openai_api_key = values.get('openai_api_key') - mistralai_api_key = values.get('mistralai_api_key') - - if not bool(openai_api_key) and not bool(mistralai_api_key): - raise ValueError('At least one of openai_api_key or mistralai_api_key must be specified') - - return values diff --git a/turbo_alignment/metrics/reward.py b/turbo_alignment/metrics/reward.py index 851e8ee..dee37a1 100755 --- a/turbo_alignment/metrics/reward.py +++ b/turbo_alignment/metrics/reward.py @@ -47,7 +47,7 @@ def compute(self, **kwargs) -> list[MetricResults]: messages = [record['messages'] for record in dataset.records for _ in range(answers_per_context)] answers = [ - AnswerMessage(id=ans_idx, content=ans) + AnswerMessage(id=str(ans_idx), content=ans) for ctx_answers in predictions for ans_idx, ans in enumerate(ctx_answers) ] diff --git a/turbo_alignment/modeling/common/transformer.py b/turbo_alignment/modeling/common/transformer.py index 3582c64..eebc380 100755 --- a/turbo_alignment/modeling/common/transformer.py +++ b/turbo_alignment/modeling/common/transformer.py @@ -77,7 +77,16 @@ class MultiheadAttention(torch.nn.MultiheadAttention): # pylint: disable=arguments-differ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: # type: ignore[override] - return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + return super().forward( + query=x, + key=x, + value=x, + key_padding_mask=None, + need_weights=False, + attn_mask=attn_mask, + average_attn_weights=True, + is_causal=False, + )[0] class BlockWithMasking(torch.nn.Module): diff --git a/turbo_alignment/modeling/imagebind/heads/registry.py b/turbo_alignment/modeling/imagebind/heads/registry.py index e1f1388..9dbd9d7 100755 --- a/turbo_alignment/modeling/imagebind/heads/registry.py +++ b/turbo_alignment/modeling/imagebind/heads/registry.py @@ -1,3 +1,5 @@ +from abc import ABC + import torch from turbo_alignment.modeling.common.helpers import SelectElement, SelectEOSAndProject @@ -7,7 +9,7 @@ ) -class Heads(torch.nn.ModuleDict): # pylint: disable=abstract-method +class Heads(ABC, torch.nn.ModuleDict): @staticmethod def __get_vision_head( vision_embed_dim: int, diff --git a/turbo_alignment/modeling/imagebind/imagebind.py b/turbo_alignment/modeling/imagebind/imagebind.py index 269dd02..e039f96 100755 --- a/turbo_alignment/modeling/imagebind/imagebind.py +++ b/turbo_alignment/modeling/imagebind/imagebind.py @@ -1,4 +1,3 @@ -# pylint: skip-file # pylint: disable=unused-import import torch diff --git a/turbo_alignment/modeling/imagebind/postprocessors/registry.py b/turbo_alignment/modeling/imagebind/postprocessors/registry.py index d0a7bcd..b3d631d 100755 --- a/turbo_alignment/modeling/imagebind/postprocessors/registry.py +++ 
b/turbo_alignment/modeling/imagebind/postprocessors/registry.py @@ -1,3 +1,5 @@ +from abc import ABC + import torch from turbo_alignment.modeling.common.helpers import LearnableLogitScaling, Normalize @@ -7,7 +9,7 @@ ) -class Postprocessors(torch.nn.ModuleDict): # pylint: disable=abstract-method +class Postprocessors(ABC, torch.nn.ModuleDict): def __init__(self, _settings: ImageBindArchitectureSettings): super().__init__() diff --git a/turbo_alignment/modeling/imagebind/preprocessors/impl.py b/turbo_alignment/modeling/imagebind/preprocessors/impl.py index 25d3a4a..ced8476 100755 --- a/turbo_alignment/modeling/imagebind/preprocessors/impl.py +++ b/turbo_alignment/modeling/imagebind/preprocessors/impl.py @@ -177,12 +177,11 @@ def get_pos_embedding(self, vision_input, all_vision_tokens): class RGBDTPreprocessor(VerboseNNModule): - # pylint: disable=dangerous-default-value def __init__( self, rgbt_stem: PatchEmbedGeneric | None, depth_stem: PatchEmbedGeneric | None, - img_size: list = [3, 224, 224], + img_size: list | None = None, num_cls_tokens: int = 1, pos_embed_fn: Callable | None = None, use_type_embed: bool = False, @@ -190,6 +189,8 @@ def __init__( ) -> None: super().__init__() stem = rgbt_stem if rgbt_stem is not None else depth_stem + if img_size is None: + img_size = [3, 224, 224] assert stem is not None ( self.patches_layout, @@ -277,11 +278,10 @@ def forward(self, vision: torch.Tensor | None = None, depth: torch.Tensor | None class AudioPreprocessor(RGBDTPreprocessor): - # pylint: disable=arguments-differ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None: super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs) - def forward(self, audio: torch.Tensor | None = None) -> dict: # type: ignore [override] + def forward(self, *_args, audio: torch.Tensor | None = None, **_kwargs) -> dict: # vision here is actually audio return super().forward(vision=audio) diff --git a/turbo_alignment/modeling/imagebind/preprocessors/registry.py b/turbo_alignment/modeling/imagebind/preprocessors/registry.py index ae5e65d..d5e6043 100755 --- a/turbo_alignment/modeling/imagebind/preprocessors/registry.py +++ b/turbo_alignment/modeling/imagebind/preprocessors/registry.py @@ -1,3 +1,4 @@ +from abc import ABC from functools import partial import torch @@ -16,7 +17,7 @@ ) -class Preprocessors(torch.nn.ModuleDict): # pylint: disable=abstract-method +class Preprocessors(ABC, torch.nn.ModuleDict): @staticmethod def __get_rgbt_preprocessor( video_frames: int, diff --git a/turbo_alignment/modeling/imagebind/trunks/registry.py b/turbo_alignment/modeling/imagebind/trunks/registry.py index 5a5bf6b..f3e113f 100755 --- a/turbo_alignment/modeling/imagebind/trunks/registry.py +++ b/turbo_alignment/modeling/imagebind/trunks/registry.py @@ -1,3 +1,4 @@ +from abc import ABC from functools import partial import torch @@ -13,7 +14,7 @@ ) -class Trunks(torch.nn.ModuleDict): # pylint: disable=abstract-method +class Trunks(ABC, torch.nn.ModuleDict): @staticmethod def __instantiate_trunk( embed_dim: int, diff --git a/turbo_alignment/modeling/multimodal/encoders/image/clip.py b/turbo_alignment/modeling/multimodal/encoders/image/clip.py index 4d7b9bc..d803c17 100644 --- a/turbo_alignment/modeling/multimodal/encoders/image/clip.py +++ b/turbo_alignment/modeling/multimodal/encoders/image/clip.py @@ -25,8 +25,7 @@ def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, i def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor: if 
is_pickle: return inputs - # pylint: disable=line-too-long - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213 # noqa: E501 + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213 # -2 is default value of vision_feature_layer in llava config # [1:] is everything after vit [cls] token return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][ diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py index 21de1e0..69c5055 100644 --- a/turbo_alignment/modeling/multimodal/lm/projection.py +++ b/turbo_alignment/modeling/multimodal/lm/projection.py @@ -1,4 +1,5 @@ from collections import defaultdict +from pathlib import Path import torch from scipy.ndimage.measurements import find_objects, label @@ -40,14 +41,13 @@ def __init__(self, *args, **kwargs) -> None: Please, set n_modality_embs to {self.encoders[modality].n_modality_embs} in config.' if self.modality_projector_initialization_mapping: - if self.modality_projector_initialization_mapping.get(modality): + state_dict_path: Path | None = self.modality_projector_initialization_mapping.get(modality, None) + if state_dict_path is not None: logger.info(f'Loading {modality} connector weights') - state_dictionary = torch.load( - self.modality_projector_initialization_mapping[modality] # type: ignore[arg-type] - ) + state_dictionary = torch.load(state_dict_path) modality_adapters[modality].load_state_dict(state_dictionary) - logger.info(f'Sucsessfully loaded from {self.modality_projector_initialization_mapping[modality]}') + logger.info(f'Successfully loaded from {state_dict_path}') self.modality_adapters = torch.nn.ModuleDict(modality_adapters) @@ -122,7 +122,6 @@ def forward( labels: torch.LongTensor | None = None, ) -> ModelOutput: multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask) - return self.language_model( inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask ) diff --git a/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py b/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py index ed4fbf0..7bf08df 100644 --- a/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py +++ b/turbo_alignment/modeling/multimodal/projectors/c_abstractor.py @@ -1,3 +1,4 @@ +from abc import ABC from functools import partial import torch @@ -14,6 +15,7 @@ class HoneybeeVisualProjectorConfig(BaseModel): projector_type: str = 'c-abs' + initializer_range: float = 1.0 depth: int = 3 mlp_depth: int = 2 hidden_size: int = 1024 @@ -42,7 +44,7 @@ def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size: num_eos_tokens = config.num_eos_tokens if num_eos_tokens: eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size)) - torch.nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range) # type: ignore + torch.nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range) else: eos_tokens = None @@ -58,9 +60,9 @@ def build_prenorm(config: HoneybeeVisualProjectorConfig): def build_mlp(depth: int, hidden_size: int, output_hidden_size: int): - layers = [torch.nn.Linear(hidden_size, output_hidden_size)] + layers: list[torch.nn.Module] = [torch.nn.Linear(hidden_size, output_hidden_size)] for _ in range(1, depth): - layers.append(torch.nn.SiLU()) # type: ignore +
layers.append(torch.nn.SiLU()) layers.append(torch.nn.Linear(output_hidden_size, output_hidden_size)) return torch.nn.Sequential(*layers) @@ -131,8 +133,7 @@ def _load_from_state_dict(self, state_dict, *args, **kwargs): super()._load_from_state_dict(state_dict, *args, **kwargs) -# pylint: disable=abstract-method -class ConvProjector(Projector): +class ConvProjector(ABC, Projector): def _forward(self, x): # x: [B, L, dim] hw = int(x.size(1) ** 0.5) diff --git a/turbo_alignment/modeling/rag/rag_model.py b/turbo_alignment/modeling/rag/rag_model.py index 9a58c2a..9bd926b 100755 --- a/turbo_alignment/modeling/rag/rag_model.py +++ b/turbo_alignment/modeling/rag/rag_model.py @@ -296,7 +296,7 @@ def generate( self, inputs: torch.LongTensor, generation_config: GenerationConfig | None = None, - **kwargs, # pylint: disable=unused-argument + **kwargs, ) -> tuple[torch.LongTensor, torch.LongTensor, torch.Tensor]: # TODO remove code duplicate with forward @@ -343,7 +343,7 @@ def generate( input_ids=joined_input_ids, generation_config=generation_config, pad_token_id=self.tokenizer.pad_token_id, - stopping_criteria=kwargs['stopping_criteria'], + tokenizer=kwargs.get('tokenizer', None), ) # TODO chose max-prob sequence with accounting for doc probs only_answer_output = output_sequences[:, joined_input_ids.shape[-1] :] diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py index c9c7ba8..f85a65d 100644 --- a/turbo_alignment/pipelines/preprocessing/multimodal.py +++ b/turbo_alignment/pipelines/preprocessing/multimodal.py @@ -47,7 +47,7 @@ def _read_modality_objects(self, reader: BaseModalityReader, dataset_path: Path) total_number_of_objects = len(list(dataset_path.iterdir())) logger.info('📖 Reading modality objects...') - files_paths = [] # type: ignore + files_paths: list[Path] = [] for extension in available_extensions: files_paths.extend(dataset_path.glob(f'*.{extension}')) @@ -121,6 +121,7 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None: experiment_settings.output_file_path.mkdir(parents=True, exist_ok=True) tensors = self._get_safetensor_dict(encoded_modality_tensors, encoded_file_paths) + del encoded_modality_tensors save_file( diff --git a/turbo_alignment/pipelines/train/ddpo.py b/turbo_alignment/pipelines/train/ddpo.py index 8a35a38..a280b6f 100755 --- a/turbo_alignment/pipelines/train/ddpo.py +++ b/turbo_alignment/pipelines/train/ddpo.py @@ -94,7 +94,8 @@ def _get_trainer( data_collator: Callable, rm_model: PreTrainedModel = None, ) -> DDPOTrainer: - model.config.use_cache = False + model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + extra_args = {'rm': rm_model} return DDPOTrainer( diff --git a/turbo_alignment/pipelines/train/dpo.py b/turbo_alignment/pipelines/train/dpo.py index 7290edd..155f6d3 100755 --- a/turbo_alignment/pipelines/train/dpo.py +++ b/turbo_alignment/pipelines/train/dpo.py @@ -73,6 +73,7 @@ def _get_trainer( data_collator: Callable, ): model.config.use_cache = not training_args.gradient_checkpointing + extra_args = {} if experiment_settings.trainer_settings.use_ref_model: ref_model = load_model(experiment_settings.model_settings, tokenizer) diff --git a/turbo_alignment/pipelines/train/kto.py b/turbo_alignment/pipelines/train/kto.py index 20c8c5f..b34d813 100755 --- a/turbo_alignment/pipelines/train/kto.py +++ b/turbo_alignment/pipelines/train/kto.py @@ -72,6 +72,8 @@ def _get_trainer( val_dataset: Dataset, data_collator: Callable, ): + 
model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + extra_args = {} if experiment_settings.trainer_settings.use_ref_model: ref_model = load_model(experiment_settings.model_settings, tokenizer) diff --git a/turbo_alignment/pipelines/train/sft.py b/turbo_alignment/pipelines/train/sft.py index 85b5346..a1bddec 100755 --- a/turbo_alignment/pipelines/train/sft.py +++ b/turbo_alignment/pipelines/train/sft.py @@ -72,6 +72,8 @@ def _get_trainer( data_collator: DataCollatorMixin, **_kwargs, ) -> MultiGPUCherryPicksTrainer: + model.config.use_cache = not experiment_settings.trainer_settings.gradient_checkpointing + return MultiGPUCherryPicksTrainer( model=model, args=training_args, diff --git a/turbo_alignment/settings/generators/chat.py b/turbo_alignment/settings/generators/chat.py index 3869baf..01bcc09 100755 --- a/turbo_alignment/settings/generators/chat.py +++ b/turbo_alignment/settings/generators/chat.py @@ -3,6 +3,5 @@ class CustomChatGenerationSettings(ExtraFieldsNotAllowedBaseModel): skip_special_tokens: bool = True - generation_eos_token: str = '' remove_prompt: bool = True batch: int = 1 diff --git a/turbo_alignment/settings/metric.py b/turbo_alignment/settings/metric.py index ee71e44..a0cecac 100755 --- a/turbo_alignment/settings/metric.py +++ b/turbo_alignment/settings/metric.py @@ -19,8 +19,6 @@ class MetricType(str, Enum): KL: str = 'kl' TOOL_CALL_METRICS: str = 'tool_call_metrics' RETRIEVAL_UTILITY: str = 'retrieval_utility' - INTENT_CLASSIFIER_ACCURACY: str = 'intent_classifier_accuracy' - RAGAS_METRICS: str = 'ragas_metrics' class ElementWiseScores(ExtraFieldsNotAllowedBaseModel): diff --git a/turbo_alignment/settings/pipelines/train/base.py b/turbo_alignment/settings/pipelines/train/base.py index e9f8227..7ed9d66 100755 --- a/turbo_alignment/settings/pipelines/train/base.py +++ b/turbo_alignment/settings/pipelines/train/base.py @@ -25,12 +25,10 @@ class BaseTrainExperimentSettings(BaseSettings): log_path: Path = Path('train_output') seed: int = 42 - # early_stopping: EarlyStoppingSettings | None = None - trainer_settings: TrainerSettings tokenizer_settings: TokenizerSettings - model_settings: ModelForPeftSettings | PreTrainedModelSettings | PreTrainedAdaptersModelSettings + model_settings: (ModelForPeftSettings | PreTrainedModelSettings | PreTrainedAdaptersModelSettings) train_dataset_settings: MultiDatasetSettings val_dataset_settings: MultiDatasetSettings diff --git a/turbo_alignment/settings/pipelines/train/dpo.py b/turbo_alignment/settings/pipelines/train/dpo.py index c7a0ae0..64d0a43 100755 --- a/turbo_alignment/settings/pipelines/train/dpo.py +++ b/turbo_alignment/settings/pipelines/train/dpo.py @@ -19,6 +19,8 @@ class DPOLossesType(str, Enum): KTO = 'kto' SLIC_HF = 'slic_hf' CPO = 'cpo' + ORPO = 'orpo' + SIMPO = 'simpo' class DPOLossSettings(ExtraFieldsNotAllowedBaseModel): @@ -57,6 +59,17 @@ class SlicHfLossSettings(DPOLossSettings): norm: bool = False +class SimPOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.SIMPO] + beta: float = 0.1 + gamma: float = 0.1 + + +class ORPOLossSettings(DPOLossSettings): + loss_type: Literal[DPOLossesType.ORPO] + beta: float = 0.1 + + class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel): sync_ref_model: bool = False alpha: float = 1.0 @@ -64,19 +77,22 @@ class SyncRefModelSettings(ExtraFieldsNotAllowedBaseModel): class DPOTrainerSettings(TrainerSettings): - loss_settings: SigmoidLossSettings | HingeLossSettings | IPOLossSettings | KTOLossSettings | CPOLossSettings + 
loss_settings: ( + SigmoidLossSettings + | HingeLossSettings + | IPOLossSettings + | KTOLossSettings + | CPOLossSettings + | ORPOLossSettings + | SimPOLossSettings + | SlicHfLossSettings + ) sync_ref_settings: SyncRefModelSettings use_ref_model: bool = True use_sft_model: bool = False average_log_prob: bool = Field(default=False, description='Normalize log probability by length or not') -class SlicHfTrainerSettings(TrainerSettings): - loss_settings: SlicHfLossSettings - use_ref_model: bool = True - average_log_prob: bool = Field(default=False, description='Normalize log probability by length or not') - - class DPOTrainExperimentSettings(BaseTrainExperimentSettings): train_dataset_settings: PairPreferenceMultiDatasetSettings val_dataset_settings: PairPreferenceMultiDatasetSettings diff --git a/turbo_alignment/settings/tf/generation.py b/turbo_alignment/settings/tf/generation.py index ee977ff..14537d9 100755 --- a/turbo_alignment/settings/tf/generation.py +++ b/turbo_alignment/settings/tf/generation.py @@ -10,3 +10,4 @@ class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel): top_p: float = 1.0 top_k: int = 50 temperature: float = 1.0 + stop_strings: str | list[str] = '' diff --git a/turbo_alignment/settings/tf/peft.py b/turbo_alignment/settings/tf/peft.py index b826760..f479a27 100755 --- a/turbo_alignment/settings/tf/peft.py +++ b/turbo_alignment/settings/tf/peft.py @@ -11,7 +11,7 @@ class BasePeftSettings(ExtraFieldsNotAllowedBaseModel): class LoraSettings(BasePeftSettings): - name: Literal[PeftType.LORA] = PeftType.LORA # type: ignore[valid-type] + name: Literal[PeftType.LORA] = PeftType.LORA r: int = 16 lora_alpha: int = 16 lora_dropout: float = 0.05 @@ -21,19 +21,19 @@ class LoraSettings(BasePeftSettings): class PrefixTuningSettings(BasePeftSettings): - name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING # type: ignore[valid-type] + name: Literal[PeftType.PREFIX_TUNING] = PeftType.PREFIX_TUNING encoder_hidden_size: int prefix_projection: bool class PromptTuningSettings(BasePeftSettings): - name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING # type: ignore[valid-type] + name: Literal[PeftType.PROMPT_TUNING] = PeftType.PROMPT_TUNING num_virtual_tokens: int = 32 prompt_tuning_init_text: str | None = None class PTuningSettings(BasePeftSettings): - name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING # type: ignore[valid-type] + name: Literal[PeftType.P_TUNING] = PeftType.P_TUNING num_virtual_tokens: int = 32 encoder_reparameterization_type: str = 'MLP' diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py index 8d91070..ff17c91 100755 --- a/turbo_alignment/trainers/dpo.py +++ b/turbo_alignment/trainers/dpo.py @@ -1,4 +1,3 @@ -# mypy: disable-error-code="call-overload" from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Callable, Literal @@ -25,10 +24,15 @@ from turbo_alignment.common.tf.callbacks.sync_ref_model import SyncRefModelCallback from turbo_alignment.constants import DISABLE_LOSS_LABEL from turbo_alignment.settings.pipelines.train.dpo import ( + CPOLossSettings, DPOLossesType, HingeLossSettings, IPOLossSettings, + KTOLossSettings, + ORPOLossSettings, SigmoidLossSettings, + SimPOLossSettings, + SlicHfLossSettings, SyncRefModelSettings, ) from turbo_alignment.trainers.utils import ( @@ -198,6 +202,7 @@ def compute_loss( return loss, chosen_rewards, rejected_rewards +@DPOLossRegistry.register(DPOLossesType.SLIC_HF) class SlicHfLoss(DPOLossRegistry): def __init__(self, delta: 
float = 1, beta: float = 1.0, lam: float = 1.0, norm: bool = False) -> None: self.delta = delta @@ -232,12 +237,83 @@ def compute_loss( return loss, chosen_rewards, rejected_rewards +@DPOLossRegistry.register(DPOLossesType.SIMPO) +class SimPOLoss(DPOLossRegistry): + def __init__(self, *args, beta: float = 0.1, gamma: float = 0.1, **kwargs) -> None: + self.beta = beta + self.gamma = gamma + super().__init__(*args, **kwargs) + + def compute_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor | None, + reference_rejected_logps: torch.FloatTensor | None, + policy_best_decode_logps: torch.FloatTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + pi_logratios = policy_chosen_logps - policy_rejected_logps + + logits = pi_logratios - self.gamma + + chosen_rewards = self.beta * (policy_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps).detach() + + loss = -F.logsigmoid(self.beta * logits) + + return ( + loss, + chosen_rewards, + rejected_rewards, + ) + + +@DPOLossRegistry.register(DPOLossesType.ORPO) +class ORPOLoss(DPOLossRegistry): + def __init__(self, *args, beta: float = 0.1, **kwargs) -> None: + self.beta = beta + super().__init__(*args, **kwargs) + + def compute_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor | None, + reference_rejected_logps: torch.FloatTensor | None, + policy_best_decode_logps: torch.FloatTensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7)) + - torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7)) + ) + + ratio = -F.logsigmoid(log_odds) + losses = self.beta * ratio + + chosen_rewards = self.beta * policy_chosen_logps.detach() + rejected_rewards = self.beta * policy_rejected_logps.detach() + + return losses, chosen_rewards, rejected_rewards + + @dataclass class DPOTrainingArguments(TrainingArguments): - loss_settings: SigmoidLossSettings | HingeLossSettings | IPOLossSettings | SlicHfLoss | KTOLoss | CPOLoss = field( + loss_settings: ( + SigmoidLossSettings + | HingeLossSettings + | IPOLossSettings + | SlicHfLossSettings + | KTOLossSettings + | CPOLossSettings + | ORPOLossSettings + | SimPOLossSettings + ) = field( default_factory=SigmoidLossSettings(loss_type=DPOLossesType.SIGMOID) + ) # type: ignore[call-overload] + sync_ref_settings: SyncRefModelSettings = field( # type: ignore[call-overload] + default_factory=SyncRefModelSettings() ) - sync_ref_settings: SyncRefModelSettings = field(default_factory=SyncRefModelSettings()) use_ref_model: bool = True use_sft_model: bool = False average_log_prob: bool = False @@ -269,6 +345,10 @@ def __init__( if hasattr(args, 'loss_settings'): self.loss_type = args.loss_settings['loss_type'] # type: ignore[index] + + if self.loss_type in (DPOLossesType.SIMPO, DPOLossesType.ORPO) and not args.average_log_prob: + raise ValueError(f'You should normalize log probabilities by length when using {self.loss_type}') + loss_args = args.loss_settings loss_args.pop('loss_type') # type: ignore[union-attr] self.dpo_loss_registry = DPOLossRegistry.by_name(self.loss_type)(**loss_args) @@ -286,6 +366,9 @@ def __init__( **kwargs, ) + if hasattr(args, 'loss_settings') and self.loss_type in (DPOLossesType.SIMPO, DPOLossesType.ORPO): +
@@ -286,6 +366,9 @@ def __init__(
             **kwargs,
         )
 
+        if hasattr(args, 'loss_settings') and self.loss_type in (DPOLossesType.SIMPO, DPOLossesType.ORPO):
+            logger.info(f'You can turn off ref_model when using {self.loss_type} for memory saving')
+
         self.ref_model = ref_model
         self.sft_model = sft_model
@@ -407,7 +490,10 @@ def get_batch_metrics(
             policy_best_decode_logps,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps = self._get_logps(self.ref_model, batch)
+        reference_chosen_logps, reference_rejected_logps = torch.Tensor([float('inf')]), torch.Tensor([float('inf')])
+
+        if self.args.use_ref_model or self.loss_type not in (DPOLossesType.SIMPO, DPOLossesType.ORPO):
+            reference_chosen_logps, reference_rejected_logps = self._get_logps(self.ref_model, batch)
 
         losses, chosen_rewards, rejected_rewards = self.dpo_loss(
             policy_chosen_logps=policy_chosen_logps,
@@ -423,14 +509,16 @@
         metrics = self._compute_metrics(metrics, dpo_prefix_name, chosen_rewards, rejected_rewards)
 
-        metrics[f'{prefix}logps/ref_rejected'] = (reference_rejected_logps).detach().cpu().mean().item()
-        metrics[f'{prefix}logps/ref_chosen'] = (reference_chosen_logps).detach().cpu().mean().item()
         metrics[f'{prefix}logps/rejected'] = (policy_rejected_logps).detach().cpu().mean().item()
         metrics[f'{prefix}logps/chosen'] = (policy_chosen_logps).detach().cpu().mean().item()
 
         metrics[f'{prefix}logits/rejected'] = (policy_rejected_logits).detach().cpu().mean().item()
         metrics[f'{prefix}logits/chosen'] = (policy_chosen_logits).detach().cpu().mean().item()
 
+        if self.args.use_ref_model:
+            metrics[f'{prefix}logps/ref_rejected'] = (reference_rejected_logps).detach().cpu().mean().item()
+            metrics[f'{prefix}logps/ref_chosen'] = (reference_chosen_logps).detach().cpu().mean().item()
+
         if self.loss_type == DPOLossesType.KTO:
             kto_chosen_KL = (
                 (policy_chosen_logps.detach().cpu() - reference_chosen_logps.detach().cpu()).mean().clamp(min=0)
@@ -448,6 +536,22 @@
             metrics[f'{prefix}rewards/kto_grad_term_chosen'] = kto_grad_term_chosen.item()
             metrics[f'{prefix}rewards/kto_grad_term_rejected'] = kto_grad_term_rejected.item()
 
+        elif self.loss_type == DPOLossesType.ORPO:
+            labels_w = batch['inputs_w']['labels'][:, 1:].clone()
+            loss_mask_w = labels_w != DISABLE_LOSS_LABEL
+            length_norm_policy_chosen_logps = policy_chosen_logps / loss_mask_w.sum(-1)
+
+            log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+                torch.log1p(-torch.clamp(torch.exp(policy_chosen_logps), max=1 - 1e-7))
+                - torch.log1p(-torch.clamp(torch.exp(policy_rejected_logps), max=1 - 1e-7))
+            )
+            ratio = -F.logsigmoid(log_odds)
+            nll_loss = -length_norm_policy_chosen_logps
+
+            metrics[f'{prefix}orpo/nll_loss'] = nll_loss.clone().detach().cpu().mean().item()
+            metrics[f'{prefix}orpo/ratio'] = (ratio).detach().cpu().mean().item()
+            metrics[f'{prefix}orpo/log_odds'] = (log_odds).detach().cpu().mean().item()
+
         if self.sft_model is not None:
             sft_chosen_logps, sft_rejected_logps = self._get_logps(self.sft_model, batch)
@@ -463,7 +567,7 @@
             sft_prefix_name = prefix + 'rewards/sft_'
             metrics = self._compute_metrics(metrics, sft_prefix_name, sft_chosen_rewards, sft_rejected_rewards)
 
-        return losses.mean(), metrics  # type: ignore
+        return losses.mean(), metrics
 
     def _compute_metrics(
         self, metrics: dict[str, float], prefix_name: str, chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor
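The ORPO metrics block above recomputes the loss's log-odds term; the `clamp(..., max=1 - 1e-7)` keeps `log1p(-p)` finite when the sequence probability `p = exp(logp)` approaches 1. A standalone sketch of that term, with illustrative inputs:

```python
import torch
import torch.nn.functional as F

def log_odds_of(logps: torch.Tensor) -> torch.Tensor:
    # log(p / (1 - p)) = logp - log1p(-p); clamping p below 1 avoids log(0)
    p = torch.clamp(torch.exp(logps), max=1 - 1e-7)
    return logps - torch.log1p(-p)

chosen_logps = torch.tensor([-0.2])   # illustrative length-normalized log-probs
rejected_logps = torch.tensor([-2.0])

log_odds = log_odds_of(chosen_logps) - log_odds_of(rejected_logps)
ratio = -F.logsigmoid(log_odds)  # the term ORPO scales by beta
print(log_odds, ratio)  # ratio is small when chosen has much higher odds
```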
@@ -523,7 +627,7 @@ def prediction_step(
             'logits_test/rejected': metrics['logits_test/rejected'],
         }
         logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = torch.stack(logits).mean(axis=1)  # type: ignore[arg-type, call-overload]
+        logits = torch.stack(logits).mean(axis=1)  # type: ignore[call-overload, arg-type]
         labels = torch.zeros(logits.shape[0])
 
         return loss.detach(), logits, labels
diff --git a/turbo_alignment/trainers/kto.py b/turbo_alignment/trainers/kto.py
index 29adceb..e3e7c6e 100755
--- a/turbo_alignment/trainers/kto.py
+++ b/turbo_alignment/trainers/kto.py
@@ -31,8 +31,8 @@ class KTOTrainingArguments(TrainingArguments):
     beta: float = 0.1
 
     sync_ref_settings: SyncRefModelSettings = field(
-        default_factory=SyncRefModelSettings()  # type: ignore[call-overload]
-    )
+        default_factory=SyncRefModelSettings()
+    )  # type: ignore[call-overload]
     use_ref_model: bool = True
     average_log_prob: bool = False
     undesirable_weight: float = 1.0
diff --git a/turbo_alignment/trainers/multigpu.py b/turbo_alignment/trainers/multigpu.py
index 3a75e39..7c84076 100755
--- a/turbo_alignment/trainers/multigpu.py
+++ b/turbo_alignment/trainers/multigpu.py
@@ -1,7 +1,5 @@
-import os
 from typing import Callable
 
-import numpy as np
 from torch import nn
 from torch.utils.data import Dataset
 from transformers import (
@@ -16,7 +14,6 @@
     TrainingArguments,
 )
 from transformers.integrations import get_reporting_integration_callbacks
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
 from turbo_alignment.common.tf.callbacks.common import WandbMetricsCallbackHandler
 
@@ -61,56 +58,3 @@ def __init__(
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
         self.control: TrainerControl = self.callback_handler.on_init_end(self.args, self.state, self.control)
-
-    def _save_checkpoint(self, model, trial, metrics=None):  # pylint: disable=unused-argument
-        # FIX: https://github.com/huggingface/transformers/issues/28119
-        # https://github.com/huggingface/transformers/pull/29370
-
-        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
-        # want to save except FullyShardedDDP.
-        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
-
-        # Save model checkpoint
-        checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
-
-        if self.hp_search_backend is None and trial is None:
-            self.store_flos()
-
-        run_dir = self._get_output_dir(trial=trial)
-        output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
-
-        if not self.args.save_only_model:
-            # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
-            # Save RNG state
-            self._save_rng_state(output_dir)
-
-        # Determine the new best metric / best model checkpoint
-        if metrics is not None and self.args.metric_for_best_model is not None:
-            metric_to_check = self.args.metric_for_best_model
-            if not metric_to_check.startswith('eval_'):
-                metric_to_check = f'eval_{metric_to_check}'
-            metric_value = metrics[metric_to_check]
-
-            operator = np.greater if self.args.greater_is_better else np.less
-            if (
-                self.state.best_metric is None
-                or self.state.best_model_checkpoint is None
-                or operator(metric_value, self.state.best_metric)
-            ):
-                self.state.best_metric = metric_value
-                self.state.best_model_checkpoint = output_dir
-
-        # Save the Trainer state
-        if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, 'trainer_state.json'))
-
-        if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
-
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            # Solely rely on numerical checkpoint id for rotation.
-            # mtime is not reliable especially on some fuse fs in cloud environments.
-            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
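The recurring `# type: ignore[call-overload]` in `kto.py` and `dpo.py` above stems from `field(default_factory=SyncRefModelSettings())` passing an instance where dataclasses expect a zero-argument callable. A sketch of the conventional pattern, using a stand-in pydantic model rather than the repository's real settings class:

```python
from dataclasses import dataclass, field

from pydantic import BaseModel


class SyncRefModelSettings(BaseModel):  # stand-in for the real settings model
    sync_ref_model: bool = False


@dataclass
class Args:
    # passing the class (or a lambda) defers construction until instantiation
    sync_ref_settings: SyncRefModelSettings = field(default_factory=SyncRefModelSettings)


print(Args().sync_ref_settings)  # sync_ref_model=False
```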
diff --git a/turbo_alignment/trainers/multimodal.py b/turbo_alignment/trainers/multimodal.py
index ea87f21..59189dc 100755
--- a/turbo_alignment/trainers/multimodal.py
+++ b/turbo_alignment/trainers/multimodal.py
@@ -26,7 +26,7 @@ class TrainerCustomSave(MultiGPUCherryPicksTrainer):
-    def _save_checkpoint(self, model, trial, metrics=None):  # pylint: disable=unused-argument
+    def _save_checkpoint(self, model, trial, metrics=None):
         logger.info('Running custom _save_checkpoint')
         checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
         run_dir = self._get_output_dir(trial=trial)
@@ -53,9 +53,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
         (output_dir / 'adapter').mkdir(parents=True, exist_ok=True)
         (output_dir / 'tokenizer').mkdir(parents=True, exist_ok=True)
 
-        torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
+        if isinstance(model, torch.nn.DataParallel):
+            torch.save(
+                model.module.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt'
+            )
+            model.module.language_model.save_pretrained(output_dir / 'adapter')
+        else:
+            torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
+            model.language_model.save_pretrained(output_dir / 'adapter')
 
-        model.language_model.save_pretrained(output_dir / 'adapter')
         self.tokenizer.save_pretrained(output_dir / 'tokenizer')
@@ -269,7 +275,7 @@ def prediction_step(
             'logits_test/rejected': metrics['logits_test/rejected'],
         }
         logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = torch.stack(logits).mean(axis=1)  # type: ignore[arg-type, call-overload]
+        logits = torch.stack(logits).mean(axis=1)  # type: ignore[call-overload, arg-type]
         labels = torch.zeros(logits.shape[0])
 
         return loss.detach(), logits, labels
@@ -285,7 +291,7 @@ def log(self, logs: Dict[str, float]) -> None:
                 del self._stored_metrics[train_eval]
 
         return super().log(logs)  # pylint: disable=no-member
 
-    def _save_checkpoint(self, model, trial, metrics=None):  # pylint: disable=unused-argument
+    def _save_checkpoint(self, model, trial, metrics=None):
         logger.info('Running custom _save_checkpoint')
         checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
         run_dir = self._get_output_dir(trial=trial)
diff --git a/turbo_alignment/trainers/utils.py b/turbo_alignment/trainers/utils.py
index ebd054a..8021d0a 100755
--- a/turbo_alignment/trainers/utils.py
+++ b/turbo_alignment/trainers/utils.py
@@ -54,10 +54,10 @@ def prepare_model_for_deepspeed(
     config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
     if model is not None:
         if hasattr(model, 'config'):
-            hidden_size: int | None = (  # type: ignore
-                max(model.config.hidden_sizes)  # type: ignore
-                if getattr(model.config, 'hidden_sizes', None)  # type: ignore
-                else getattr(model.config, 'hidden_size', None)  # type: ignore
+            hidden_size: int | None = (
+                max(model.config.hidden_sizes)
+                if getattr(model.config, 'hidden_sizes', None)
+                else getattr(model.config, 'hidden_size', None)
             )
 
             if hidden_size is not None and config_kwargs['zero_optimization']['stage'] == 3:
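The `DataParallel` branch added to `TrainerCustomSave._save_checkpoint` above follows the standard unwrap pattern: `nn.DataParallel` hides the wrapped network behind `.module`, so state dicts must be read from there. A toy illustration; the module and output path are invented for the sketch:

```python
import torch
from torch import nn


def unwrap(model: nn.Module) -> nn.Module:
    # DataParallel keeps the original network on .module; plain modules pass through
    return model.module if isinstance(model, nn.DataParallel) else model


inner = nn.Linear(4, 2)
wrapped = nn.DataParallel(inner)
assert unwrap(wrapped) is inner

torch.save(unwrap(wrapped).state_dict(), 'modality_adapters_sketch.pt')  # invented path
```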
diff --git a/tutorials/multimodal/create_tutorial_dataset.py b/tutorials/multimodal/create_tutorial_dataset.py
new file mode 100644
index 0000000..ee7582d
--- /dev/null
+++ b/tutorials/multimodal/create_tutorial_dataset.py
@@ -0,0 +1,43 @@
+import json
+import random
+import subprocess
+from pathlib import Path
+from typing import Any
+
+from datasets import load_dataset
+
+from turbo_alignment.common.data.io import write_jsonl
+from turbo_alignment.dataset.chat.models import ChatMessageRole
+from turbo_alignment.dataset.multimodal.models import (
+    MultimodalChatMessage,
+    MultimodalDatasetRecord,
+    MultimodalImageMessage,
+    MultimodalTextMessage,
+)
+from turbo_alignment.settings.modality import Modality
+
+
+def convert_to_multimodal_record(row):
+    return MultimodalDatasetRecord(
+        id=row['id'],
+        messages=[
+            MultimodalImageMessage(role=ChatMessageRole.USER, content=f"images/00000/{str(row['id']).zfill(9)}.jpg"),
+            MultimodalTextMessage(role=ChatMessageRole.BOT, content=row['image_descriptions'][0].strip()),
+        ],
+    ).dict()
+
+
+if __name__ == '__main__':
+    dataset = load_dataset('passing2961/photochat_plus')['train']
+    dataset = dataset.add_column('id', range(len(dataset)))
+    dataset = dataset.train_test_split(test_size=0.1)
+
+    dataset['train_multimodal_records'] = dataset['train'].map(
+        convert_to_multimodal_record, remove_columns=dataset['train'].column_names
+    )
+    dataset['val_multimodal_records'] = dataset['test'].map(
+        convert_to_multimodal_record, remove_columns=dataset['test'].column_names
+    )
+
+    write_jsonl([item for item in dataset['train_multimodal_records']], Path('train_multimodal.jsonl'))
+    write_jsonl([item for item in dataset['val_multimodal_records']], Path('val_multimodal.jsonl'))
diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json
new file mode 100644
index 0000000..6910ab3
--- /dev/null
+++ b/tutorials/multimodal/multimodal.json
@@ -0,0 +1,196 @@
+{
+    "train_dataset_settings": {
+        "sources": [
+            {
+                "name": "train",
+                "records_path": "train_chat.jsonl",
+                "num_samples": 10
+            }
+        ],
+        "prompt_template": {
+            "role_tag_mapping": {
+                "bot": "",
+                "user": "",
+                "system": ""
+            },
+            "prefix_template": "{role}",
+            "suffix_template": ""
+        },
+        "modality_token_mapping": {
+            "image": "",
+            "audio": "