Commit b63c973: IterableDataset for multimodal pipeline
lmeribal committed Sep 12, 2024 (1 parent: 79af8d9)
Showing 7 changed files with 196 additions and 63 deletions.
1 change: 1 addition & 0 deletions turbo_alignment/common/data/multimodal/common.py
@@ -21,6 +21,7 @@ def _get_pt_files(path: Path) -> Path:
return list(path.glob('*.pt'))

def read(self, path: str) -> torch.Tensor:
print('Calling a read!')
if self.processed_batches is None:
self.processed_batches = {}
pt_files = self._get_pt_files(Path(path).parent)
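Note: read() lazily populates a cache from the .pt files that sit next to the requested path. A minimal self-contained sketch of that pattern (the cache keying and file layout are assumptions, not the repo's exact API):

import torch
from pathlib import Path

class CachedTensorReader:
    def __init__(self) -> None:
        # Filled on first read; maps a sample path/key to its encoded tensor.
        self.processed_batches: dict[str, torch.Tensor] | None = None

    @staticmethod
    def _get_pt_files(path: Path) -> list[Path]:
        return list(path.glob('*.pt'))

    def read(self, path: str) -> torch.Tensor:
        if self.processed_batches is None:
            self.processed_batches = {}
            for pt_file in self._get_pt_files(Path(path).parent):
                # Assumption: each .pt file stores a dict of {key: tensor}.
                self.processed_batches.update(torch.load(pt_file))
        return self.processed_batches[path]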
86 changes: 84 additions & 2 deletions turbo_alignment/dataset/base/base.py
@@ -4,7 +4,7 @@
from typing import Any, Generic, TypeVar, overload

import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterableDataset
from transformers import PreTrainedTokenizerBase

from turbo_alignment.common.logging import get_project_logger
@@ -101,9 +101,91 @@ def _read_records(records: list[dict]) -> list[RecordT]:
@abstractmethod
def _read_records(records):
...

class BaseIterableDataset(IterableDataset, ABC, Generic[RecordT]):
def __init__(
self,
source: DatasetSourceSettings,
settings: BaseDatasetSettings,
) -> None:
self.source = source
self.settings = settings

self.original_records_map: dict[str, RecordT] = {}
self.records: list[dict[str, torch.Tensor]] = []

def _read(self) -> None:
if self.source.records_data:
records = self._read_records(self.source.records_data)
elif self.source.records_path:
records = self._read_records(self.source.records_path)
else:
raise ValueError('At least one of records_data and records_path must be provided')

if self.source.offset is not None and self.source.n_rows is not None:
records = records[self.source.offset : self.source.offset + self.source.n_rows]

self.original_records_map, self.records = self._sample_dataset(records)

logger.info(f'Sampled {len(self.records)} records with offset {self.source.offset}')

def _sample_dataset(
self,
original_records: list[RecordT],
) -> tuple[dict[str, RecordT], list[dict[str, Any]]]:
if self.source.sample_rate is not None:
logger.info(f'Sampling dataset {self.source.name} with sample rate: {self.source.sample_rate}')
sampled_original_records = {
record.id: record for record in original_records if random.random() <= self.source.sample_rate
}
elif self.source.num_samples is not None:
logger.info(f'Sampling {self.source.num_samples} from dataset {self.source.name}')
sampled_original_records = {
record.id: record
for record in random.sample(original_records, k=min(self.source.num_samples, len(original_records)))
}
else:
raise ValueError('Neither sample_rate nor num_samples is set')

sampled_records = [r for r in self.convert_records(list(sampled_original_records.values())) if r is not None]

return sampled_original_records, sampled_records

def __len__(self) -> int:
return len(self.records)

def __getitem__(self, index: int) -> dict[str, Any]:
return self.records[index]

def __iter__(self):
return iter(self.records)

def get_original_record_by_id(self, record_id: str) -> RecordT:
return self.original_records_map[record_id]

@abstractmethod
def convert_records(self, records: list[RecordT]) -> list[dict[str, Any] | None]:
...

@staticmethod
@abstractmethod
@overload
def _read_records(records: Path) -> list[RecordT]:
...

@staticmethod
@abstractmethod
@overload
def _read_records(records: list[dict]) -> list[RecordT]:
...

@staticmethod
@abstractmethod
def _read_records(records):
...


class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]):
class AlignmentDataset(BaseIterableDataset, ABC, Generic[RecordT]):
def __init__(
self,
source: DatasetSourceSettings,
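Note: a BaseIterableDataset subclass still materializes self.records up front via _read(); the IterableDataset base only changes how DataLoader consumes it (__iter__ instead of __getitem__ indexing). A minimal usage sketch, with placeholder dataset/collator names:

from torch.utils.data import DataLoader

# `dataset` stands for any concrete BaseIterableDataset after _read() has run;
# `collator` is whatever collate_fn the pipeline configures.
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=2,  # workers only help if __iter__ shards self.records
    collate_fn=collator,
)

One consequence worth remembering: DataLoader rejects shuffle=True for an IterableDataset, so any shuffling has to happen inside __iter__ (or upstream, in _sample_dataset).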
20 changes: 17 additions & 3 deletions turbo_alignment/dataset/multimodal/collators.py
@@ -8,9 +8,20 @@ def torch_call(self, features):
label_name = 'label' if 'label' in features[0].keys() else 'labels'
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

modality_inputs = [feature['modality_inputs'] for feature in features]

modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask')
if 'modality_inputs' in features[0].keys():
modality_inputs = [feature['modality_inputs'] for feature in features]
else:
modality_inputs = [None for feature in features]

if 'messages' in features[0].keys():
message_inputs = [feature['messages'] for feature in features]
else:
message_inputs = [None] * len(features)

if 'messages' in features[0].keys():
modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask', 'messages')
else:
modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask')
tokenizer_features = [
{k: v for k, v in feature.items() if k not in modality_input_names} for feature in features
]
@@ -29,6 +40,9 @@ def torch_call(self, features):
assert padding_side == 'right'

batch['modality_inputs'] = modality_inputs

if 'messages' in features[0].keys():
batch['messages'] = message_inputs

batch['modality_tokens_mask'] = torch.stack(
[
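Note: the collator change boils down to a key-splitting idiom: pass-through keys ('modality_inputs', 'messages', the token mask) are stripped before the tokenizer pads the batch, then re-attached afterwards. A runnable sketch of just that idiom (the feature contents are made up for illustration):

modality_input_names = ('labels', 'modality_inputs', 'modality_tokens_mask', 'messages')
features = [{'input_ids': [1, 2], 'labels': [1, 2], 'messages': ['hi']}]

message_inputs = [f.get('messages') for f in features]  # raw pass-through, None if absent
tokenizer_features = [
    {k: v for k, v in f.items() if k not in modality_input_names}
    for f in features
]
# tokenizer_features -> [{'input_ids': [1, 2]}]; after tokenizer.pad(...) the
# pass-through values are written back onto the padded batch.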
116 changes: 68 additions & 48 deletions turbo_alignment/dataset/multimodal/multimodal.py
@@ -1,4 +1,5 @@
from abc import ABC
import math
from pathlib import Path
from typing import Any, overload

@@ -116,26 +117,6 @@ def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord:

return ChatDatasetRecord(id=record.id, messages=converted_messages)

def _read_modalities(
self, record: MultimodalDatasetRecord, modality_messages_after_truncation: int
) -> list[tuple[Modality, torch.Tensor]]:
modality_messages: list[MultimodalFileMessage] = [
m for m in record.messages if isinstance(m, MultimodalFileMessage)
]

messages_to_delete = len(modality_messages) - modality_messages_after_truncation

if self._truncate_top:
modality_messages = modality_messages[messages_to_delete:]
else:
modality_messages = modality_messages[:modality_messages_after_truncation]

modality_encodings: list[tuple[Modality, torch.Tensor]] = []
for msg in modality_messages:
reader = self._modality_readers[msg.type]
modality_encodings.append((msg.type, reader.read(msg.content)))
return modality_encodings


@MultimodalDatasetTypeRegistry.register(DatasetStrategy.TRAIN)
class TrainMultimodalDataset(MultimodalDataset):
@@ -169,19 +150,15 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
outputs.append(None)
continue

modality_messages_after_truncation = int(
(tokenized_record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
)
# try:
# encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
# except (OSError, RuntimeError, KeyError):
# outputs.append(None)
# continue

try:
encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
except (OSError, RuntimeError, KeyError):
outputs.append(None)
continue

if len(encoded_modalities) != modality_messages_after_truncation:
outputs.append(None)
continue
# if len(encoded_modalities) != modality_messages_after_truncation:
# outputs.append(None)
# continue

modality_tokens_mask = torch.isin(
tokenized_record['input_ids'],
@@ -193,13 +170,49 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
outputs.append(
{
**tokenized_record,
'modality_inputs': encoded_modalities,
# 'modality_inputs': encoded_modalities,
'messages': record.messages,
'modality_tokens_mask': modality_tokens_mask,
}
)

return outputs

def _read_modalities(self, record):
modality_messages_after_truncation = int(
(record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
)

modality_messages: list[MultimodalFileMessage] = [
m for m in record['messages'] if isinstance(m, MultimodalFileMessage)
]

messages_to_delete = len(modality_messages) - modality_messages_after_truncation

if self._truncate_top:
modality_messages = modality_messages[messages_to_delete:]
else:
modality_messages = modality_messages[:modality_messages_after_truncation]

modality_encodings: list[tuple[Modality, torch.Tensor]] = []
for msg in modality_messages:
reader = self._modality_readers[msg.type]
modality_encodings.append((msg.type, reader.read(msg.content)))
record['modality_inputs'] = modality_encodings
return record

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()

start = 0
end = len(self.records)
if worker_info:
per_worker = int(math.ceil(len(self.records) / float(worker_info.num_workers)))
worker_id = worker_info.id
start = start + worker_id * per_worker
end = min(start + per_worker, end)
print("🩻"*10, f"{float(worker_info.num_workers)=}, {per_worker=}, {worker_info.id=}, {start=}, {end=}")

return map(self._read_modalities, iter(self.records[start:end]))


@MultimodalDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
class InferenceMultimodalDataset(MultimodalDataset):
@@ -216,6 +229,27 @@ def __init__(
tokenizer=kwargs['tokenizer'], source=kwargs['source'], settings=settings, read=False
)
self._read()

def _read_modalities(self, record):
modality_messages_after_truncation = int(
(record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
)

modality_messages: list[MultimodalFileMessage] = [
m for m in record['messages'] if isinstance(m, MultimodalFileMessage)
]

messages_to_delete = len(modality_messages) - modality_messages_after_truncation

if self._truncate_top:
modality_messages = modality_messages[messages_to_delete:]
else:
modality_messages = modality_messages[:modality_messages_after_truncation]

modality_encodings: list[tuple[Modality, torch.Tensor]] = []
for msg in modality_messages:
reader = self._modality_readers[msg.type]
modality_encodings.append((msg.type, reader.read(msg.content)))
record['modality_inputs'] = modality_encodings
return record

def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
chat_records = [self._convert_to_chat(r) for r in records]
@@ -228,20 +262,6 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
outputs.append(None)
continue

modality_messages_after_truncation = int(
(tokenized_record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
)

try:
encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
except (OSError, RuntimeError, KeyError):
outputs.append(None)
continue

if len(encoded_modalities) != modality_messages_after_truncation:
outputs.append(None)
continue

modality_tokens_mask = torch.isin(
tokenized_record['input_ids'],
torch.tensor([self._get_token_id(token) for token in self._modality_token_mapping.values()]),
@@ -255,7 +275,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
outputs.append(
{
**tokenized_record,
'modality_inputs': encoded_modalities,
'messages': record.messages,
'modality_tokens_mask': modality_tokens_mask,
'modality_object_paths': modality_object_paths,
}
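Note: the new __iter__ follows the worker-sharding recipe from the PyTorch IterableDataset docs, so each DataLoader worker reads a disjoint contiguous slice of self.records and no record is encoded twice. The canonical pattern, as a standalone sketch:

import math
import torch

def worker_shard(n_records: int) -> tuple[int, int]:
    # Returns the [start, end) slice the current DataLoader worker should iterate.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return 0, n_records  # single-process loading: iterate everything
    per_worker = int(math.ceil(n_records / float(worker_info.num_workers)))
    start = worker_info.id * per_worker
    end = min(start + per_worker, n_records)  # clamp, don't drop the tail
    return start, end

Deferring reader.read() into __iter__ (via map(self._read_modalities, ...)) is what makes the modality tensors load lazily in the worker processes instead of during dataset construction.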
1 change: 1 addition & 0 deletions turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -4,6 +4,7 @@
from typing import Tuple

import numpy as np

import torch
from accelerate import Accelerator
import os
34 changes: 24 additions & 10 deletions turbo_alignment/pipelines/train/base.py
@@ -156,22 +156,36 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
special_tokens_setter.setup_model_config(self.model)

logger.info('Model is loaded!')

# train_dataset: ConcatDataset = ConcatDataset(
# datasets=DatasetLoader().load_datasets(
# experiment_settings.train_dataset_settings,
# tokenizer=self.tokenizer,
# strategy=DatasetStrategy.TRAIN,
# )
# )

# val_dataset: ConcatDataset = ConcatDataset(
# datasets=DatasetLoader().load_datasets(
# experiment_settings.val_dataset_settings,
# tokenizer=self.tokenizer,
# strategy=DatasetStrategy.TRAIN,
# )
# )

train_dataset = DatasetLoader().load_datasets(
experiment_settings.train_dataset_settings,
tokenizer=self.tokenizer,
strategy=DatasetStrategy.TRAIN,
)[0]

val_dataset = DatasetLoader().load_datasets(
experiment_settings.val_dataset_settings,
tokenizer=self.tokenizer,
strategy=DatasetStrategy.TRAIN,
)[0]


data_collator = self._get_data_collator(experiment_settings, self.tokenizer)

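Note: load_datasets evidently returns a list of datasets, hence the trailing [0]; ConcatDataset only makes sense for map-style datasets, which is why it is commented out above. If several sources ever need to be combined again, torch.utils.data.ChainDataset is the iterable-side analogue; a sketch under that assumption:

from torch.utils.data import ChainDataset

datasets = DatasetLoader().load_datasets(
    experiment_settings.train_dataset_settings,
    tokenizer=self.tokenizer,
    strategy=DatasetStrategy.TRAIN,
)
train_dataset = datasets[0] if len(datasets) == 1 else ChainDataset(datasets)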
1 change: 1 addition & 0 deletions turbo_alignment/settings/tf/trainer.py
@@ -47,3 +47,4 @@ class TrainerSettings(ExtraFieldsNotAllowedBaseModel):
gradient_checkpointing_kwargs: dict[str, Any] = {}
neftune_noise_alpha: float | None = None
report_to: list[str] = []
dispatch_batches: bool | None = None
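
Note: dispatch_batches presumably flows through to transformers' TrainingArguments, where (in releases current at the time of this commit) it controlled whether the main process iterates the dataloader and dispatches slices to the other ranks; with an IterableDataset it is typically set to False so each process draws from its own shard. A hedged sketch of the intended plumbing:

from transformers import TrainingArguments

# Assumption: TrainerSettings fields are forwarded into TrainingArguments.
args = TrainingArguments(
    output_dir='out',
    dispatch_batches=False,  # each process iterates its own IterableDataset shard
)

(Newer transformers releases fold this flag into accelerator_config.)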
