diff --git a/docs/dataset_example.md b/docs/dataset_example.md
index 7ca2f81..6ea91c5 100644
--- a/docs/dataset_example.md
+++ b/docs/dataset_example.md
@@ -11,7 +11,7 @@
- [Pair Preferences Dataset](#-pair-preferences-dataset)
- [KTO Dataset](#-kto-dataset)
- [Sampling Dataset](#-sampling-dataset)
-- [Multimodal Dataset ](#-multimodal-dataset) (⌛️ Work in progress...)
+- [Multimodal Dataset](#-multimodal-dataset)
- [Classification Dataset](#-classification-dataset)
-- [DPPO Dataset](#-ddpo-dataset) (⌛️ Work in progress...)
+- [DDPO Dataset](#-ddpo-dataset) (⌛️ Work in progress...)
@@ -118,9 +118,45 @@ Example:
## Multimodal Dataset
-⌛️ in progress..
+- `messages`: `list[MultimodalChatMessage]` - A sequence of messages that make up the chat history. Each `MultimodalChatMessage` includes:
+  - `role` - The participant's role in the conversation (e.g., `user` or `bot`).
+  - `type` - The modality of the message (e.g., `text` or `image`).
+  - `content` - The textual content of the message if `type` is `text`, or the path to the image file if `type` is `image`.
+
+Example:
+```json
+{
+ "id": "0",
+ "messages": [
+ {
+ "role": "system",
+ "type": "text",
+ "content": "You are a Multimodal AI assistant."
+ },
+ {
+ "role": "user",
+ "type": "image",
+ "content": "/path/to/cat.jpg"
+ },
+ {
+ "role": "user",
+ "type": "image",
+ "content": "/path/to/dog.jpg"
+ },
+ {
+ "role": "user",
+ "type": "text",
+ "content": "What's the difference between these two images?"
+ },
+ {
+ "role": "bot",
+ "type": "text",
+ "content": "The two images in question both feature animals, albeit of different species. The first image depicts a dog, which is generally perceived as an animal that elicits positive emotional responses. The second image features a cat, which is also regarded as an animal that evokes a positive emotional response."
+ }
+ ]
+}
+```
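+
+The same record can also be constructed programmatically with the dataset models from `turbo_alignment.dataset.multimodal.models` (a minimal sketch; see `tutorials/multimodal/create_tutorial_dataset.py` for a full example):
+
+```python
+from turbo_alignment.dataset.chat.models import ChatMessageRole
+from turbo_alignment.dataset.multimodal.models import (
+    MultimodalDatasetRecord,
+    MultimodalImageMessage,
+    MultimodalTextMessage,
+)
+
+# Build a record with an image turn and a question from the user, followed by the bot reply
+record = MultimodalDatasetRecord(
+    id='0',
+    messages=[
+        MultimodalImageMessage(role=ChatMessageRole.USER, content='/path/to/cat.jpg'),
+        MultimodalTextMessage(role=ChatMessageRole.USER, content="What animal is in this image?"),
+        MultimodalTextMessage(role=ChatMessageRole.BOT, content='The image shows a cat.'),
+    ],
+)
+
+# Serializes to the structure shown in the JSON example above
+print(record.dict())
+```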
diff --git a/tests/fixtures/datasets/multimodal/images/image.clip.safetensors b/tests/fixtures/datasets/multimodal/images/image.clip.safetensors
index bb4d44b..af7685f 100644
Binary files a/tests/fixtures/datasets/multimodal/images/image.clip.safetensors and b/tests/fixtures/datasets/multimodal/images/image.clip.safetensors differ
diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py
index 60afa6e..fc4d9b4 100644
--- a/turbo_alignment/common/data/multimodal/common.py
+++ b/turbo_alignment/common/data/multimodal/common.py
@@ -25,5 +25,4 @@ def read(self, path: str) -> torch.Tensor:
safetensors_file = self._get_safetensors_file(Path(path).parent)
if self.processed_tensors is None:
self.processed_tensors = safe_open(safetensors_file, framework='pt', device='cpu')
-
return self.processed_tensors.get_tensor(Path(path).name)
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index 0f19c97..54fcde2 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -190,6 +190,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
)
tokenized_record['labels'][modality_tokens_mask] = DISABLE_LOSS_LABEL
+
outputs.append(
{
**tokenized_record,
diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index 90062e6..69c5055 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -122,7 +122,6 @@ def forward(
labels: torch.LongTensor | None = None,
) -> ModelOutput:
multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask)
-
return self.language_model(
inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask
)
diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index 5ade1cf..f85a65d 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -121,6 +121,7 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
experiment_settings.output_file_path.mkdir(parents=True, exist_ok=True)
tensors = self._get_safetensor_dict(encoded_modality_tensors, encoded_file_paths)
+
del encoded_modality_tensors
save_file(
diff --git a/turbo_alignment/trainers/multimodal.py b/turbo_alignment/trainers/multimodal.py
index 0441365..59189dc 100755
--- a/turbo_alignment/trainers/multimodal.py
+++ b/turbo_alignment/trainers/multimodal.py
@@ -53,9 +53,16 @@ def _save_checkpoint(self, model, trial, metrics=None):
(output_dir / 'adapter').mkdir(parents=True, exist_ok=True)
(output_dir / 'tokenizer').mkdir(parents=True, exist_ok=True)
-        torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
-        model.language_model.save_pretrained(output_dir / 'adapter')
+        # Unwrap the DataParallel container, if present, so the underlying submodules are saved
+        if isinstance(model, torch.nn.DataParallel):
+            torch.save(
+                model.module.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt'
+            )
+            model.module.language_model.save_pretrained(output_dir / 'adapter')
+        else:
+            torch.save(model.modality_adapters.state_dict(), output_dir / 'projections' / 'modality_adapters.pt')
+            model.language_model.save_pretrained(output_dir / 'adapter')
self.tokenizer.save_pretrained(output_dir / 'tokenizer')
diff --git a/tutorials/multimodal/create_tutorial_dataset.py b/tutorials/multimodal/create_tutorial_dataset.py
new file mode 100644
index 0000000..ee7582d
--- /dev/null
+++ b/tutorials/multimodal/create_tutorial_dataset.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+
+from datasets import load_dataset
+
+from turbo_alignment.common.data.io import write_jsonl
+from turbo_alignment.dataset.chat.models import ChatMessageRole
+from turbo_alignment.dataset.multimodal.models import (
+ MultimodalDatasetRecord,
+ MultimodalImageMessage,
+ MultimodalTextMessage,
+)
+
+
+# Convert one PhotoChat+ row into a multimodal record: the image as a user turn, its first description as the bot reply
+def convert_to_multimodal_record(row):
+ return MultimodalDatasetRecord(
+ id=row['id'],
+ messages=[
+            # Image files are assumed to be downloaded separately into images/00000/; only their paths are stored here
+            MultimodalImageMessage(role=ChatMessageRole.USER, content=f"images/00000/{str(row['id']).zfill(9)}.jpg"),
+ MultimodalTextMessage(role=ChatMessageRole.BOT, content=row['image_descriptions'][0].strip()),
+ ],
+ ).dict()
+
+
+if __name__ == '__main__':
+ dataset = load_dataset('passing2961/photochat_plus')['train']
+ dataset = dataset.add_column('id', range(len(dataset)))
+ dataset = dataset.train_test_split(test_size=0.1)
+
+ dataset['train_multimodal_records'] = dataset['train'].map(
+ convert_to_multimodal_record, remove_columns=dataset['train'].column_names
+ )
+ dataset['val_multimodal_records'] = dataset['test'].map(
+ convert_to_multimodal_record, remove_columns=dataset['test'].column_names
+ )
+
+ write_jsonl([item for item in dataset['train_multimodal_records']], Path('train_multimodal.jsonl'))
+ write_jsonl([item for item in dataset['val_multimodal_records']], Path('val_multimodal.jsonl'))
diff --git a/tutorials/multimodal/multimodal.json b/tutorials/multimodal/multimodal.json
new file mode 100644
index 0000000..6910ab3
--- /dev/null
+++ b/tutorials/multimodal/multimodal.json
@@ -0,0 +1,196 @@
+{
+ "train_dataset_settings": {
+ "sources": [
+ {
+ "name": "train",
+ "records_path": "train_chat.jsonl",
+ "num_samples": 10
+ }
+ ],
+ "prompt_template": {
+ "role_tag_mapping": {
+ "bot": "",
+ "user": "",
+ "system": ""
+ },
+ "prefix_template": "{role}",
+ "suffix_template": ""
+ },
+ "modality_token_mapping": {
+ "image": "",
+ "audio": "