From 8b44d5ff5349abe6baebc2c29e471c82501041a6 Mon Sep 17 00:00:00 2001 From: lmeribal Date: Fri, 13 Sep 2024 10:39:18 +0000 Subject: [PATCH] iterable dataset --- .../common/data/multimodal/common.py | 1 - turbo_alignment/dataset/base/__init__.py | 2 +- turbo_alignment/dataset/base/base.py | 23 +++- .../dataset/multimodal/multimodal.py | 110 +++++++++++------- turbo_alignment/pipelines/train/base.py | 4 +- 5 files changed, 96 insertions(+), 44 deletions(-) diff --git a/turbo_alignment/common/data/multimodal/common.py b/turbo_alignment/common/data/multimodal/common.py index e09c624..3a1d531 100644 --- a/turbo_alignment/common/data/multimodal/common.py +++ b/turbo_alignment/common/data/multimodal/common.py @@ -21,7 +21,6 @@ def _get_pt_files(path: Path) -> Path: return list(path.glob('*.pt')) def read(self, path: str) -> torch.Tensor: - print('Calling a read!') if self.processed_batches is None: self.processed_batches = {} pt_files = self._get_pt_files(Path(path).parent) diff --git a/turbo_alignment/dataset/base/__init__.py b/turbo_alignment/dataset/base/__init__.py index 8414b37..4302130 100755 --- a/turbo_alignment/dataset/base/__init__.py +++ b/turbo_alignment/dataset/base/__init__.py @@ -1,2 +1,2 @@ -from .base import AlignmentDataset, BaseDataset +from .base import AlignmentDataset, AlignmentIterableDataset, BaseDataset from .models import DatasetRecord diff --git a/turbo_alignment/dataset/base/base.py b/turbo_alignment/dataset/base/base.py index fcba11b..7f6fc01 100755 --- a/turbo_alignment/dataset/base/base.py +++ b/turbo_alignment/dataset/base/base.py @@ -186,7 +186,28 @@ def _read_records(records): ... 
-class AlignmentDataset(BaseIterableDataset, ABC, Generic[RecordT]): +class AlignmentDataset(BaseDataset, ABC, Generic[RecordT]): + def __init__( + self, + source: DatasetSourceSettings, + settings: BaseDatasetSettings, + tokenizer: PreTrainedTokenizerBase, + ) -> None: + super().__init__(source=source, settings=settings) + + self.tokenizer = tokenizer + self._logged = False + + def _log_example(self, prompt: str, answer: str | None = None) -> None: + if not self._logged: + message = f'Source and target examples:\n' f'Prompt: {prompt}\n' + if answer: + message += f'Answer: {answer}' + logger.info(message) + self._logged = True + + +class AlignmentIterableDataset(BaseIterableDataset, ABC, Generic[RecordT]): def __init__( self, source: DatasetSourceSettings, diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py index 383e03b..5c616ad 100644 --- a/turbo_alignment/dataset/multimodal/multimodal.py +++ b/turbo_alignment/dataset/multimodal/multimodal.py @@ -12,7 +12,7 @@ from turbo_alignment.common.data.multimodal.registry import ModalityReaderRegistry from turbo_alignment.common.logging import get_project_logger from turbo_alignment.constants import DISABLE_LOSS_LABEL -from turbo_alignment.dataset.base import AlignmentDataset +from turbo_alignment.dataset.base import AlignmentDataset, AlignmentIterableDataset from turbo_alignment.dataset.chat.chat import InferenceChatDataset, TrainChatDataset from turbo_alignment.dataset.chat.models import ChatDatasetRecord, ChatMessage from turbo_alignment.dataset.multimodal.models import ( @@ -29,7 +29,7 @@ logger = get_project_logger() -class MultimodalDataset(AlignmentDataset[MultimodalDatasetRecord], ABC): +class MultimodalDataset(AlignmentIterableDataset[MultimodalDatasetRecord], ABC): def __init__(self, tokenizer, source, settings): super().__init__(tokenizer=tokenizer, source=source, settings=settings) @@ -93,29 +93,6 @@ def __get_modality_message(self, modality: 
Modality) -> str: modality_message_span = ''.join(modality_token for _ in range(self._n_modality_embeddings)) return f'{self._start_modality_token}{modality_message_span}{self._end_modality_token}' - def _read_modalities(self, record): - modality_messages_after_truncation = int( - (self.records[0]['input_ids'] == self._get_token_id(self._start_modality_token)).sum() - ) - - modality_messages: list[MultimodalFileMessage] = [ - m for m in record['messages'] if isinstance(m, MultimodalFileMessage) - ] - - messages_to_delete = len(modality_messages) - modality_messages_after_truncation - - if self._truncate_top: - modality_messages = modality_messages[messages_to_delete:] - else: - modality_messages = modality_messages[:modality_messages_after_truncation] - - modality_encodings: list[tuple[Modality, torch.Tensor]] = [] - for msg in modality_messages: - reader = self._modality_readers[msg.type] - modality_encodings.append((msg.type, reader.read(msg.content))) - record['modality_inputs'] = modality_encodings - return record - def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord: """ Обычные текстовые сообщения оставляем без изменений, а мультимодальные сообщения @@ -145,7 +122,6 @@ def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord class TrainMultimodalDataset(MultimodalDataset): def __init__(self, tokenizer, source, settings) -> None: """ - :param n_modality_embeddings: сколько токенов выделяем под одно сообщение с нетекстовой модальностью :param modality_token_mapping: modality -> token :param start_modality_token: начало блока с токенами, которые относятся к нетекстовой модальности @@ -183,7 +159,6 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s outputs.append( { **tokenized_record, - # 'modality_inputs': encoded_modalities, 'messages': record.messages, 'modality_tokens_mask': modality_tokens_mask, } @@ -191,17 +166,38 @@ def convert_records(self, records: 
list[MultimodalDatasetRecord]) -> list[dict[s return outputs - def __iter__(self): - # try: - # encoded_modalities = self._read_modalities(record, modality_messages_after_truncation) - # except (OSError, RuntimeError, KeyError): - # outputs.append(None) - # continue + def _read_modalities(self, record): + modality_messages_after_truncation = int( + (record['input_ids'] == self._get_token_id(self._start_modality_token)).sum() + ) + + modality_messages: list[MultimodalFileMessage] = [ + m for m in record['messages'] if isinstance(m, MultimodalFileMessage) + ] + + messages_to_delete = len(modality_messages) - modality_messages_after_truncation + + if self._truncate_top: + modality_messages = modality_messages[messages_to_delete:] + else: + modality_messages = modality_messages[:modality_messages_after_truncation] + + modality_encodings: list[tuple[Modality, torch.Tensor]] = [] + try: + for msg in modality_messages: + reader = self._modality_readers[msg.type] + modality_encodings.append((msg.type, reader.read(msg.content))) + except (OSError, RuntimeError, KeyError): + return None + + record['modality_inputs'] = modality_encodings + + if len(modality_encodings) != modality_messages_after_truncation: + return None - # if len(encoded_modalities) != modality_messages_after_truncation: - # outputs.append(None) - # continue + return record + def __iter__(self): worker_info = torch.utils.data.get_worker_info() start = 0 @@ -211,8 +207,9 @@ def __iter__(self): worker_id = worker_info.id start = start + worker_id * per_worker end = min(start + per_worker, end) - - return map(self._read_modalities, iter(self.records[start:end])) + for sample in map(self._read_modalities, iter(self.records[start:end])): + if sample is not None: + yield sample @MultimodalDatasetTypeRegistry.register(DatasetStrategy.INFERENCE) @@ -231,6 +228,27 @@ def __init__( ) self._read() + def _read_modalities( + self, record: MultimodalDatasetRecord, modality_messages_after_truncation: int + ) -> 
list[tuple[Modality, torch.Tensor]]: + modality_messages: list[MultimodalFileMessage] = [ + m for m in record.messages if isinstance(m, MultimodalFileMessage) + ] + + messages_to_delete = len(modality_messages) - modality_messages_after_truncation + + if self._truncate_top: + modality_messages = modality_messages[messages_to_delete:] + else: + modality_messages = modality_messages[:modality_messages_after_truncation] + + modality_encodings: list[tuple[Modality, torch.Tensor]] = [] + for msg in modality_messages: + reader = self._modality_readers[msg.type] + # NOTE(review): debug print removed + modality_encodings.append((msg.type, reader.read(msg.content))) + return modality_encodings + def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]: chat_records = [self._convert_to_chat(r) for r in records] tokenized_chat_records = self._chat_dataset.convert_records(chat_records) @@ -242,6 +260,20 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s outputs.append(None) continue + modality_messages_after_truncation = int( + (tokenized_record['input_ids'] == self._get_token_id(self._start_modality_token)).sum() + ) + + try: + encoded_modalities = self._read_modalities(record, modality_messages_after_truncation) + except (OSError, RuntimeError, KeyError): + outputs.append(None) + continue + + if len(encoded_modalities) != modality_messages_after_truncation: + outputs.append(None) + continue + modality_tokens_mask = torch.isin( tokenized_record['input_ids'], torch.tensor([self._get_token_id(token) for token in self._modality_token_mapping.values()]), @@ -255,7 +287,7 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s outputs.append( { **tokenized_record, - 'messages': record.messages, + 'modality_inputs': encoded_modalities, 'modality_tokens_mask': modality_tokens_mask, 'modality_object_paths': modality_object_paths, } diff --git a/turbo_alignment/pipelines/train/base.py 
b/turbo_alignment/pipelines/train/base.py index 78e52c2..0aff572 100755 --- a/turbo_alignment/pipelines/train/base.py +++ b/turbo_alignment/pipelines/train/base.py @@ -173,13 +173,13 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None: # ) # ) - train_dataset = datasets = DatasetLoader().load_datasets( + train_dataset = DatasetLoader().load_datasets( experiment_settings.train_dataset_settings, tokenizer=self.tokenizer, strategy=DatasetStrategy.TRAIN, )[0] - val_dataset = datasets = DatasetLoader().load_datasets( + val_dataset = DatasetLoader().load_datasets( experiment_settings.val_dataset_settings, tokenizer=self.tokenizer, strategy=DatasetStrategy.TRAIN,