From b63c9739504cbe14e9a8c4452b5917fc9a695e78 Mon Sep 17 00:00:00 2001
From: lmeribal
Date: Thu, 12 Sep 2024 15:42:39 +0000
Subject: [PATCH] IterableDataset for multimodal pipeline

---
 .../common/data/multimodal/common.py          |   1 +
 turbo_alignment/dataset/base/base.py          |  86 ++++++++++++-
 .../dataset/multimodal/collators.py           |  20 ++-
 .../dataset/multimodal/multimodal.py          | 116 ++++++++++--------
 .../pipelines/preprocessing/multimodal.py     |   1 +
 turbo_alignment/pipelines/train/base.py       |  34 +++--
 turbo_alignment/settings/tf/trainer.py        |   1 +
 7 files changed, 196 insertions(+), 63 deletions(-)

diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py
index 3a1d531..e09c624 100644
--- a/turbo_alignment/common/data/multimodal/common.py
+++ b/turbo_alignment/common/data/multimodal/common.py
@@ -21,6 +21,7 @@ def _get_pt_files(path: Path) -> Path:
         return list(path.glob('*.pt'))
 
     def read(self, path: str) -> torch.Tensor:
+        print(f'Reading modality tensor from {path}')
         if self.processed_batches is None:
             self.processed_batches = {}
         pt_files = self._get_pt_files(Path(path).parent)
diff --git a/turbo_alignment/dataset/base/base.py b/turbo_alignment/dataset/base/base.py
index 71e353e..2183b80 100755
--- a/turbo_alignment/dataset/base/base.py
+++ b/turbo_alignment/dataset/base/base.py
@@ -4,7 +4,7 @@
 from typing import Any, Generic, TypeVar, overload
 
 import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, IterableDataset
 from transformers import PreTrainedTokenizerBase
 
 from turbo_alignment.common.logging import get_project_logger
@@ -101,9 +101,91 @@ def _read_records(records: list[dict]) -> list[RecordT]:
     @abstractmethod
     def _read_records(records):
         ...
+
+
+class BaseIterableDataset(IterableDataset, ABC, Generic[RecordT]):
+    def __init__(
+        self,
+        source: DatasetSourceSettings,
+        settings: BaseDatasetSettings,
+    ) -> None:
+        self.source = source
+        self.settings = settings
+
+        self.original_records_map: dict[str, RecordT] = {}
+        self.records: list[dict[str, Any]] = []
+
+    def _read(self) -> None:
+        if self.source.records_data:
+            records = self._read_records(self.source.records_data)
+        elif self.source.records_path:
+            records = self._read_records(self.source.records_path)
+        else:
+            raise ValueError('At least one of records_data and records_path should not be None')
+
+        if self.source.offset is not None and self.source.n_rows is not None:
+            records = records[self.source.offset : self.source.offset + self.source.n_rows]
+
+        self.original_records_map, self.records = self._sample_dataset(records)
+
+        logger.info(f'Sampled {len(self.records)} records with offset {self.source.offset}')
+
+    def _sample_dataset(
+        self,
+        original_records: list[RecordT],
+    ) -> tuple[dict[str, RecordT], list[dict[str, Any]]]:
+        if self.source.sample_rate is not None:
+            logger.info(f'Sampling dataset {self.source.name} with sample rate: {self.source.sample_rate}')
+            sampled_original_records = {
+                record.id: record for record in original_records if random.random() <= self.source.sample_rate
+            }
+        elif self.source.num_samples is not None:
+            logger.info(f'Sampling {self.source.num_samples} from dataset {self.source.name}')
+            sampled_original_records = {
+                record.id: record
+                for record in random.sample(original_records, k=min(self.source.num_samples, len(original_records)))
+            }
+        else:
+            raise ValueError('Neither sample_rate nor num_samples is set')
+
+        sampled_records = [r for r in self.convert_records(list(sampled_original_records.values())) if r is not None]
+
+        return sampled_original_records, sampled_records
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        return self.records[index]
+
+    def __iter__(self):
+        return iter(self.records)
+
+    def get_original_record_by_id(self, record_id: str) -> RecordT:
+        return self.original_records_map[record_id]
+
+    @abstractmethod
+    def convert_records(self, records: list[RecordT]) -> list[dict[str, Any] | None]:
+        ...
+
+    @staticmethod
+    @abstractmethod
+    @overload
+    def _read_records(records: Path) -> list[RecordT]:
+        ...
+
+    @staticmethod
+    @abstractmethod
+    @overload
+    def _read_records(records: list[dict]) -> list[RecordT]:
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def _read_records(records):
+        ...
 
-class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]):
+class AlignmentDataset(BaseIterableDataset, ABC, Generic[RecordT]):
     def __init__(
         self,
         source: DatasetSourceSettings,
diff --git a/turbo_alignment/dataset/multimodal/collators.py b/turbo_alignment/dataset/multimodal/collators.py
index 63911d3..71c9295 100644
--- a/turbo_alignment/dataset/multimodal/collators.py
+++ b/turbo_alignment/dataset/multimodal/collators.py
@@ -8,9 +8,20 @@ def torch_call(self, features):
         label_name = 'label' if 'label' in features[0].keys() else 'labels'
         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
 
-        modality_inputs = [feature['modality_inputs'] for feature in features]
-
-        modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask')
+        if 'modality_inputs' in features[0].keys():
+            modality_inputs = [feature['modality_inputs'] for feature in features]
+        else:
+            modality_inputs = [None] * len(features)
+
+        if 'messages' in features[0].keys():
+            message_inputs = [feature['messages'] for feature in features]
+            modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask', 'messages')
+        else:
+            message_inputs = [None] * len(features)
+            modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask')
         tokenizer_features = [
             {k: v for k, v in feature.items() if k not in modality_input_names} for feature in features
         ]
@@ -29,6 +40,9 @@ def torch_call(self, features):
         assert padding_side == 'right'
 
         batch['modality_inputs'] = modality_inputs
+
+        if 'messages' in features[0].keys():
+            batch['messages'] = message_inputs
 
         batch['modality_tokens_mask'] = torch.stack(
             [
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index 401604d..45e3d66 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -1,4 +1,5 @@
 from abc import ABC
+import math
 from pathlib import Path
 from typing import Any, overload
 
@@ -116,26 +117,6 @@ def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord
 
         return ChatDatasetRecord(id=record.id, messages=converted_messages)
 
-    def _read_modalities(
-        self, record: MultimodalDatasetRecord, modality_messages_after_truncation: int
-    ) -> list[tuple[Modality, torch.Tensor]]:
-        modality_messages: list[MultimodalFileMessage] = [
-            m for m in record.messages if isinstance(m, MultimodalFileMessage)
-        ]
-
-        messages_to_delete = len(modality_messages) - modality_messages_after_truncation
-
-        if self._truncate_top:
-            modality_messages = modality_messages[messages_to_delete:]
-        else:
-            modality_messages = modality_messages[:modality_messages_after_truncation]
-
-        modality_encodings: list[tuple[Modality, torch.Tensor]] = []
-        for msg in modality_messages:
-            reader = self._modality_readers[msg.type]
-            modality_encodings.append((msg.type, reader.read(msg.content)))
-        return modality_encodings
-
 
 @MultimodalDatasetTypeRegistry.register(DatasetStrategy.TRAIN)
 class TrainMultimodalDataset(MultimodalDataset):
@@ -169,19 +150,15 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
                 outputs.append(None)
                 continue
 
-            modality_messages_after_truncation = int(
-                (tokenized_record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
-            )
-
-            try:
-                encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
-            except (OSError, RuntimeError, KeyError):
-                outputs.append(None)
-                continue
-
-            if len(encoded_modalities) != modality_messages_after_truncation:
-                outputs.append(None)
-                continue
-
             modality_tokens_mask = torch.isin(
                 tokenized_record['input_ids'],
@@ -193,13 +170,49 @@
             outputs.append(
                 {
                     **tokenized_record,
-                    'modality_inputs': encoded_modalities,
+                    'messages': record.messages,
                     'modality_tokens_mask': modality_tokens_mask,
                 }
             )
 
         return outputs
 
+    def _read_modalities(self, record: dict[str, Any]) -> dict[str, Any]:
+        modality_messages_after_truncation = int(
+            (record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
+        )
+
+        modality_messages: list[MultimodalFileMessage] = [
+            m for m in record['messages'] if isinstance(m, MultimodalFileMessage)
+        ]
+
+        messages_to_delete = len(modality_messages) - modality_messages_after_truncation
+
+        if self._truncate_top:
+            modality_messages = modality_messages[messages_to_delete:]
+        else:
+            modality_messages = modality_messages[:modality_messages_after_truncation]
+
+        modality_encodings: list[tuple[Modality, torch.Tensor]] = []
+        for msg in modality_messages:
+            reader = self._modality_readers[msg.type]
+            modality_encodings.append((msg.type, reader.read(msg.content)))
+        record['modality_inputs'] = modality_encodings
+        return record
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+
+        start = 0
+        end = len(self.records)
+        if worker_info is not None:
+            per_worker = int(math.ceil(len(self.records) / float(worker_info.num_workers)))
+            start = worker_info.id * per_worker
+            end = min(start + per_worker, end)
+
+        return map(self._read_modalities, iter(self.records[start:end]))
+
 
 @MultimodalDatasetTypeRegistry.register(DatasetStrategy.INFERENCE)
 class InferenceMultimodalDataset(MultimodalDataset):
@@ -216,6 +229,27 @@ def __init__(
             tokenizer=kwargs['tokenizer'], source=kwargs['source'], settings=settings, read=False
         )
         self._read()
+
+    def _read_modalities(self, record: dict[str, Any]) -> dict[str, Any]:
+        modality_messages_after_truncation = int(
+            (record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
+        )
+
+        modality_messages: list[MultimodalFileMessage] = [
+            m for m in record['messages'] if isinstance(m, MultimodalFileMessage)
+        ]
+
+        messages_to_delete = len(modality_messages) - modality_messages_after_truncation
+
+        if self._truncate_top:
+            modality_messages = modality_messages[messages_to_delete:]
+        else:
+            modality_messages = modality_messages[:modality_messages_after_truncation]
+
+        modality_encodings: list[tuple[Modality, torch.Tensor]] = []
+        for msg in modality_messages:
+            reader = self._modality_readers[msg.type]
+            modality_encodings.append((msg.type, reader.read(msg.content)))
+        record['modality_inputs'] = modality_encodings
+        return record
 
     def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
         chat_records = [self._convert_to_chat(r) for r in records]
@@ -228,20 +262,6 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s
                 outputs.append(None)
                 continue
 
-            modality_messages_after_truncation = int(
-                (tokenized_record['input_ids'] == self._get_token_id(self._start_modality_token)).sum()
-            )
-
-            try:
-                encoded_modalities = self._read_modalities(record, modality_messages_after_truncation)
-            except (OSError, RuntimeError, KeyError):
-                outputs.append(None)
-                continue
-
-            if len(encoded_modalities) != modality_messages_after_truncation:
-                outputs.append(None)
-                continue
-
             modality_tokens_mask = torch.isin(
                 tokenized_record['input_ids'],
                 torch.tensor([self._get_token_id(token) for token in self._modality_token_mapping.values()]),
@@ -255,7 +275,7 @@
             outputs.append(
                 {
                     **tokenized_record,
-                    'modality_inputs': encoded_modalities,
+                    'messages': record.messages,
                     'modality_tokens_mask': modality_tokens_mask,
                     'modality_object_paths': modality_object_paths,
                 }
diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index d9ea1fa..7c9ef73 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -4,6 +4,7 @@
 from typing import Tuple
 
 import numpy as np
+
 import torch
 from accelerate import Accelerator
 import os
diff --git a/turbo_alignment/pipelines/train/base.py b/turbo_alignment/pipelines/train/base.py
index 57e4a20..2f5cdad 100755
--- a/turbo_alignment/pipelines/train/base.py
+++ b/turbo_alignment/pipelines/train/base.py
@@ -156,22 +156,36 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None:
         special_tokens_setter.setup_model_config(self.model)
 
         logger.info('Model is loaded!')
 
-        train_dataset: ConcatDataset = ConcatDataset(
-            datasets=DatasetLoader().load_datasets(
+        train_dataset = DatasetLoader().load_datasets(
                 experiment_settings.train_dataset_settings,
                 tokenizer=self.tokenizer,
                 strategy=DatasetStrategy.TRAIN,
-            )
-        )
-
-        val_dataset: ConcatDataset = ConcatDataset(
-            datasets=DatasetLoader().load_datasets(
+        )[0]
+
+        val_dataset = DatasetLoader().load_datasets(
                 experiment_settings.val_dataset_settings,
                 tokenizer=self.tokenizer,
                 strategy=DatasetStrategy.TRAIN,
-            )
-        )
+        )[0]
 
         data_collator = self._get_data_collator(experiment_settings, self.tokenizer)
diff --git a/turbo_alignment/settings/tf/trainer.py b/turbo_alignment/settings/tf/trainer.py
index 662ee62..0d17ccd 100755
--- a/turbo_alignment/settings/tf/trainer.py
+++ b/turbo_alignment/settings/tf/trainer.py
@@ -47,3 +47,4 @@ class TrainerSettings(ExtraFieldsNotAllowedBaseModel):
     gradient_checkpointing_kwargs: dict[str, Any] = {}
     neftune_noise_alpha: float | None = None
     report_to: list[str] = []
+    dispatch_batches: bool | None = None
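
Usage note, not taken from the repository: with the sharded __iter__ in TrainMultimodalDataset, each DataLoader worker process picks up its own ceil(len(records) / num_workers) slice via torch.utils.data.get_worker_info() and fills 'modality_inputs' lazily per record in _read_modalities, instead of reading every modality tensor up front in convert_records. A minimal sketch of how such a dataset is expected to be consumed; train_dataset and data_collator stand in for the objects built in pipelines/train/base.py:

    from torch.utils.data import DataLoader


    def build_train_loader(train_dataset, data_collator, batch_size=8):
        # Each of the 4 worker processes runs TrainMultimodalDataset.__iter__,
        # takes its shard of self.records, and reads modality tensors lazily
        # while iterating; the collator pads text fields and passes
        # 'messages' / 'modality_inputs' through untouched.
        return DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=4,
            collate_fn=data_collator,
        )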
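On the new TrainerSettings.dispatch_batches field: Accelerate defaults dispatch_batches to True when the dataloader wraps an IterableDataset, so only the main process iterates the dataset and broadcasts batch slices to the other ranks; passing False lets every rank iterate its own copy, which matches the per-worker slicing in __iter__. A hedged sketch of the intended wiring, assuming the setting is forwarded directly to transformers.TrainingArguments (which still accepts dispatch_batches in releases contemporary with this patch):

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir='train_output',
        # None keeps Accelerate's default (True for IterableDataset);
        # False makes each process iterate its own dataset copy.
        dispatch_batches=False,
    )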