diff --git a/docs/dataset_example.md b/docs/dataset_example.md
index 7ca2f81..6ea91c5 100644
--- a/docs/dataset_example.md
+++ b/docs/dataset_example.md
@@ -11,7 +11,7 @@
 - [Pair Preferences Dataset](#-pair-preferences-dataset)
 - [KTO Dataset](#-kto-dataset)
 - [Sampling Dataset](#-sampling-dataset)
-- [Multimodal Dataset ](#-multimodal-dataset) (⌛️ Work in progress...)
+- [Multimodal Dataset ](#-multimodal-dataset)
 - [Classification Dataset](#-classification-dataset)
 - [DPPO Dataset](#-ddpo-dataset) (⌛️ Work in progress...)
@@ -118,9 +118,45 @@
 Example:

 ## Multimodal Dataset

-⌛️ in progress..
+- `messages`: `list[MultimodalChatMessage]` — This is a sequence of messages that make up the chat history. Each `MultimodalChatMessage` includes:
+  - `role` - The participant's role in the conversation (e.g., `user` or `bot`).
+  - `type` - The modality of the message (e.g., `text` or `image`).
+  - `content` - If `type` is `text`, this is the textual content of the message. If `type` is `image`, this is the path to the image file.
+Example:
+```json
+{
+  "id": "0",
+  "messages": [
+    {
+      "role": "system",
+      "type": "text",
+      "content": "You are a Multimodal AI assistant."
+    },
+    {
+      "role": "user",
+      "type": "image",
+      "content": "/path/to/cat.jpg"
+    },
+    {
+      "role": "user",
+      "type": "image",
+      "content": "/path/to/dog.jpg"
+    },
+    {
+      "role": "user",
+      "type": "text",
+      "content": "What's the difference between these two images?"
+    },
+    {
+      "role": "bot",
+      "type": "text",
+      "content": "The two images both feature animals, albeit of different species. The first image depicts a cat, while the second features a dog; both are generally perceived as animals that elicit positive emotional responses."
+    }
+  ]
+}
+```
diff --git a/tests/fixtures/datasets/multimodal/images/image.clip.safetensors b/tests/fixtures/datasets/multimodal/images/image.clip.safetensors
index bb4d44b..af7685f 100644
Binary files a/tests/fixtures/datasets/multimodal/images/image.clip.safetensors and b/tests/fixtures/datasets/multimodal/images/image.clip.safetensors differ
diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py
index 60afa6e..fc4d9b4 100644
--- a/turbo_alignment/common/data/multimodal/common.py
+++ b/turbo_alignment/common/data/multimodal/common.py
@@ -25,5 +25,4 @@ def read(self, path: str) -> torch.Tensor:
         safetensors_file = self._get_safetensors_file(Path(path).parent)
         if self.processed_tensors is None:
             self.processed_tensors = safe_open(safetensors_file, framework='pt', device='cpu')
-
         return self.processed_tensors.get_tensor(Path(path).name)
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index 0f19c97..54fcde2 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -190,6 +190,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
             )

             tokenized_record['labels'][modality_tokens_mask] = DISABLE_LOSS_LABEL
+
             outputs.append(
                 {
                     **tokenized_record,
diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index 90062e6..69c5055 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -122,7 +122,6 @@ def forward(
         labels: torch.LongTensor | None = None,
     ) -> ModelOutput:
         multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask)
-
         return self.language_model(
             inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask
         )
diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index 5ade1cf..f85a65d 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -121,6 +121,7 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
         experiment_settings.output_file_path.mkdir(parents=True, exist_ok=True)

         tensors = self._get_safetensor_dict(encoded_modality_tensors, encoded_file_paths)
+
         del encoded_modality_tensors

         save_file(
diff --git a/turbo_alignment/trainers/multimodal.py b/turbo_alignment/trainers/multimodal.py
index 0441365..59189dc 100755
--- a/turbo_alignment/trainers/multimodal.py
+++ b/turbo_alignment/trainers/multimodal.py
@@ -53,9 +53,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
         (output_dir / 'adapter').mkdir(parents=True, exist_ok=True)
         (output_dir / 'tokenizer').mkdir(parents=True, exist_ok=True)

-        torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
+        if isinstance(model, torch.nn.DataParallel):
+            torch.save(
+                model.module.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt'
+            )
+            model.module.language_model.save_pretrained(output_dir / 'adapter')
+        else:
+            torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
+            model.language_model.save_pretrained(output_dir / 'adapter')

-        model.language_model.save_pretrained(output_dir / 'adapter')

         self.tokenizer.save_pretrained(output_dir / 'tokenizer')
diff --git a/tutorials/multimodal/create_tutorial_dataset.py b/tutorials/multimodal/create_tutorial_dataset.py
new file mode 100644
index 0000000..ee7582d
--- /dev/null
+++ b/tutorials/multimodal/create_tutorial_dataset.py
@@ -0,0 +1,43 @@
+import json
+import random
+import subprocess
+from pathlib import Path
+from typing import Any
+
+from datasets import load_dataset
+
+from turbo_alignment.common.data.io import write_jsonl
+from turbo_alignment.dataset.chat.models import ChatMessageRole
+from turbo_alignment.dataset.multimodal.models import (
+    MultimodalChatMessage,
+    MultimodalDatasetRecord,
+    MultimodalImageMessage,
+    MultimodalTextMessage,
+)
+from turbo_alignment.settings.modality import Modality
+
+
+def convert_to_multimodal_record(row):
+    return MultimodalDatasetRecord(
+        id=row['id'],
+        messages=[
+            MultimodalImageMessage(role=ChatMessageRole.USER, content=f"images/00000/{str(row['id']).zfill(9)}.jpg"),
+            MultimodalTextMessage(role=ChatMessageRole.BOT, content=row['image_descriptions'][0].strip()),
+        ],
+    ).dict()
+
+
+if __name__ == '__main__':
+    dataset = load_dataset('passing2961/photochat_plus')['train']
+    dataset = dataset.add_column('id', range(len(dataset)))
+    dataset = dataset.train_test_split(test_size=0.1)
+
+    dataset['train_multimodal_records'] = dataset['train'].map(
+        convert_to_multimodal_record, remove_columns=dataset['train'].column_names
+    )
+    dataset['val_multimodal_records'] = dataset['test'].map(
+        convert_to_multimodal_record, remove_columns=dataset['test'].column_names
+    )
+
+    write_jsonl([item for item in dataset['train_multimodal_records']], Path('train_multimodal.jsonl'))
+    write_jsonl([item for item in dataset['val_multimodal_records']], Path('val_multimodal.jsonl'))
diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json
new file mode 100644
index 0000000..6910ab3
--- /dev/null
+++ b/tutorials/multimodal/multimodal.json
@@ -0,0 +1,196 @@
+{
+  "train_dataset_settings": {
+    "sources": [
+      {
+        "name": "train",
+        "records_path": "train_chat.jsonl",
+        "num_samples": 10
+      }
+    ],
+    "prompt_template": {
+      "role_tag_mapping": {
+        "bot": "",
+        "user": "",
+        "system": ""
+      },
+      "prefix_template": "{role}",
+      "suffix_template": ""
+    },
+    "modality_token_mapping": {
+      "image": "",
+      "audio": "