diff --git a/turbo_alignment/dataset/base/base.py b/turbo_alignment/dataset/base/base.py index 2183b80..fcba11b 100755 --- a/turbo_alignment/dataset/base/base.py +++ b/turbo_alignment/dataset/base/base.py @@ -101,7 +101,8 @@ def _read_records(records: list[dict]) -> list[RecordT]: @abstractmethod def _read_records(records): ... - + + class BaseIterableDataset(IterableDataset, ABC, Generic[RecordT]): def __init__( self, diff --git a/turbo_alignment/dataset/multimodal/collators.py b/turbo_alignment/dataset/multimodal/collators.py index 71c9295..8f75356 100644 --- a/turbo_alignment/dataset/multimodal/collators.py +++ b/turbo_alignment/dataset/multimodal/collators.py @@ -12,12 +12,12 @@ def torch_call(self, features): modality_inputs = [feature['modality_inputs'] for feature in features] else: modality_inputs = [None for feature in features] - + if 'messages' in features[0].keys(): message_inputs = [feature['messages'] for feature in features] else: message_inputs = [None] * len(features) - + if 'messages' in features[0].keys(): modality_input_names = (label_name, 'modality_inputs', 'modality_tokens_mask', 'messages') else: @@ -40,8 +40,8 @@ def torch_call(self, features): assert padding_side == 'right' batch['modality_inputs'] = modality_inputs - - if 'messages' in features[0].keys(): + + if 'messages' in features[0].keys(): batch['messages'] = message_inputs batch['modality_tokens_mask'] = torch.stack( diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py index 45e3d66..383e03b 100644 --- a/turbo_alignment/dataset/multimodal/multimodal.py +++ b/turbo_alignment/dataset/multimodal/multimodal.py @@ -1,5 +1,5 @@ -from abc import ABC import math +from abc import ABC from pathlib import Path from typing import Any, overload @@ -93,6 +93,29 @@ def __get_modality_message(self, modality: Modality) -> str: modality_message_span = ''.join(modality_token for _ in range(self._n_modality_embeddings)) return 
f'{self._start_modality_token}{modality_message_span}{self._end_modality_token}' + def _read_modalities(self, record): + modality_messages_after_truncation = int( + (self.records[0]['input_ids'] == self._get_token_id(self._start_modality_token)).sum() + ) + + modality_messages: list[MultimodalFileMessage] = [ + m for m in record['messages'] if isinstance(m, MultimodalFileMessage) + ] + + messages_to_delete = len(modality_messages) - modality_messages_after_truncation + + if self._truncate_top: + modality_messages = modality_messages[messages_to_delete:] + else: + modality_messages = modality_messages[:modality_messages_after_truncation] + + modality_encodings: list[tuple[Modality, torch.Tensor]] = [] + for msg in modality_messages: + reader = self._modality_readers[msg.type] + modality_encodings.append((msg.type, reader.read(msg.content))) + record['modality_inputs'] = modality_encodings + return record + def _convert_to_chat(self, record: MultimodalDatasetRecord) -> ChatDatasetRecord: """ Обычные текстовые сообщения оставляем без изменений, а мультимодальные сообщения @@ -150,16 +173,6 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s outputs.append(None) continue - # try: - # encoded_modalities = self._read_modalities(record, modality_messages_after_truncation) - # except (OSError, RuntimeError, KeyError): - # outputs.append(None) - # continue - - # if len(encoded_modalities) != modality_messages_after_truncation: - # outputs.append(None) - # continue - modality_tokens_mask = torch.isin( tokenized_record['input_ids'], torch.tensor([self._get_token_id(token) for token in self._modality_token_mapping.values()]), @@ -178,30 +191,19 @@ def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[s return outputs - def _read_modalities(self, record): - modality_messages_after_truncation = int((self.records[0]['input_ids'] == self._get_token_id(self._start_modality_token)).sum()) - - modality_messages: 
list[MultimodalFileMessage] = [ - m for m in record['messages'] if isinstance(m, MultimodalFileMessage) - ] - - messages_to_delete = len(modality_messages) - modality_messages_after_truncation - - if self._truncate_top: - modality_messages = modality_messages[messages_to_delete:] - else: - modality_messages = modality_messages[:modality_messages_after_truncation] + def __iter__(self): + # try: + # encoded_modalities = self._read_modalities(record, modality_messages_after_truncation) + # except (OSError, RuntimeError, KeyError): + # outputs.append(None) + # continue - modality_encodings: list[tuple[Modality, torch.Tensor]] = [] - for msg in modality_messages: - reader = self._modality_readers[msg.type] - modality_encodings.append((msg.type, reader.read(msg.content))) - record['modality_inputs'] = modality_encodings - return record + # if len(encoded_modalities) != modality_messages_after_truncation: + # outputs.append(None) + # continue - def __iter__(self): worker_info = torch.utils.data.get_worker_info() - + start = 0 end = len(self.records) - 1 if worker_info: @@ -209,7 +211,6 @@ def __iter__(self): worker_id = worker_info.id start = start + worker_id * per_worker end = min(start + per_worker, end) - print("🩻"*10, f"{float(worker_info.num_workers)=}, {per_worker=}, {worker_info.id=}, {start=}, {end=}") return map(self._read_modalities, iter(self.records[start:end])) @@ -229,27 +230,6 @@ def __init__( tokenizer=kwargs['tokenizer'], source=kwargs['source'], settings=settings, read=False ) self._read() - - def _read_modalities(self, record): - modality_messages_after_truncation = int((self.records[0]['input_ids'] == self._get_token_id(self._start_modality_token)).sum()) - - modality_messages: list[MultimodalFileMessage] = [ - m for m in record['messages'] if isinstance(m, MultimodalFileMessage) - ] - - messages_to_delete = len(modality_messages) - modality_messages_after_truncation - - if self._truncate_top: - modality_messages = 
class PreprocessMultimodalDatasetStrategy(BaseStrategy):
    """Pre-encode modality files (e.g. images) into tensors saved on disk.

    Collects every supported file from ``experiment_settings.dataset_path``,
    encodes batches with the configured modality encoder and saves each
    encoded batch via ``torch.save``. Work is sharded across accelerate
    processes when the ``ACCELERATE_ENABLED`` env var is ``'true'``.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): BaseStrategy.__init__ is never called and *args/**kwargs
        # are discarded — confirm the base class needs no initialization.
        self.accelerator = Accelerator()

    def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
        """Entry point: build the reader/encoder pair and process the dataset."""
        if self.accelerator.is_main_process:
            logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')

        reader, encoder = self._load_modality_reader_encoder(
            experiment_settings.reader_settings,
            experiment_settings.encoder_settings,
            experiment_settings.modality,
        )
        self._read_modality_objects(reader, encoder, experiment_settings)

        if self.accelerator.is_main_process:
            logger.info('👩 Saved!')  # plain string: f-string had no placeholders

    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
        """Read, encode and pack one batch of modality files.

        :return: dict mapping file name -> encoded tensor (detached, on CPU).
        """
        modality_objects = torch.cat([reader.read(str(file_path)) for file_path in batch_file_paths])
        encoded_modality_objects = encoder.encode(modality_objects.to(self.accelerator.device)).detach().cpu()
        return self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)

    @staticmethod
    def _save_tensor(tensor, filename, experiment_settings):
        """Save one encoded tensor under a modality/encoder-specific filename."""
        filepath = experiment_settings.output_file_path / (
            filename
            + '.'
            + experiment_settings.modality.value
            + '.'
            + experiment_settings.encoder_settings.modality_encoder_type
            + '.pt'
        )
        # BUG FIX: the original logged `filepath` before it was assigned
        # (NameError) and passed tensor.shape as a stray positional argument.
        logger.info('saving %s with shape %s', filepath, tensor.shape)
        torch.save(tensor, filepath)

    @staticmethod
    def _batch_output_path(experiment_settings, prefix: str) -> Path:
        """Build the output path for one saved batch (shared by both modes)."""
        return experiment_settings.output_file_path / (
            prefix
            + '.'
            + experiment_settings.modality.value
            + '.'
            + experiment_settings.encoder_settings.modality_encoder_type
            + '.pt'
        )

    @staticmethod
    def _split_into_batches(files_paths, batch_size):
        """Split paths into ~equal batches, never fewer than one.

        Guards against ZeroDivisionError / empty split when there are fewer
        files than ``batch_size``.
        """
        num_batches = max(1, len(files_paths) // batch_size)
        return np.array_split(files_paths, num_batches)

    def _process_files(self, reader, encoder, files_paths, experiment_settings):
        """Single-process path: encode and save every batch sequentially."""
        batches_all = self._split_into_batches(files_paths, experiment_settings.batch_size)

        for i, batch in enumerate(tqdm(batches_all)):
            try:
                logger.info(f'📖 Processing batch {i} / {len(batches_all)}')
                batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
                torch.save(batch_output, self._batch_output_path(experiment_settings, 'batch_' + str(i)))
            except Exception as exc:  # best-effort: log and skip a bad batch
                logger.error(f'Error reading file: {exc}')

    def _async_process_files(self, reader, encoder, files_paths, experiment_settings):
        """Accelerate path: shard the batches across processes, then encode/save."""
        logger.info('👩 Processing with accelerate!')
        batches_all = self._split_into_batches(files_paths, experiment_settings.batch_size)

        self.accelerator.wait_for_everyone()

        with self.accelerator.split_between_processes(batches_all) as batches:
            for i, batch in enumerate(tqdm(batches)):
                try:
                    logger.info(f'📖 Encoding batch {i} / {len(batches)}')
                    batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
                    torch.save(
                        batch_output,
                        self._batch_output_path(
                            experiment_settings,
                            'process_' + str(self.accelerator.process_index) + '_batch_' + str(i),
                        ),
                    )
                except Exception as exc:  # best-effort: log and skip a bad batch
                    logger.error(f'Error reading file: {exc}')

    def _load_modality_reader_encoder(
        self,
        reader_settings: ModalityReaderSettings,
        encoder_settings: ModalityEncoderSettings,
        modality: Modality,
    ) -> Tuple[BaseModalityReader, BaseModalityEncoder]:
        """Instantiate the modality reader and move the encoder to this device."""
        device = self.accelerator.device
        reader = ModalityReaderRegistry.by_name(modality).from_params(
            Params({'type': reader_settings.reader_type, 'reader_path': reader_settings.reader_path})
        )
        encoder = ModalityEncoderRegistry.by_name(encoder_settings.modality_encoder_type)(
            encoder_path=encoder_settings.encoder_path
        ).to(device)
        return (reader, encoder)

    def _read_modality_objects(self, reader, encoder, experiment_settings):
        """Collect all modality files and dispatch to sync/accelerate processing."""
        available_extensions = ('jpg', 'jpeg', 'png', 'svg')

        if self.accelerator.is_main_process:
            logger.info('📖 Reading modality objects...')
        files_paths: list[Path] = []
        for extension in available_extensions:
            files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))

        if os.environ.get('ACCELERATE_ENABLED', 'false') == 'true':
            self._async_process_files(reader, encoder, files_paths, experiment_settings)
        else:
            self._process_files(reader, encoder, files_paths, experiment_settings)

    @staticmethod
    def _get_safetensor_dict(encoded_modality_tensors, encoded_file_paths):
        """Map file name -> detached encoded tensor for one batch."""
        tensors = {}
        for file, tensor in zip(encoded_file_paths, encoded_modality_tensors):
            tensors[file.name] = tensor.detach()
        return tensors
train_dataset: ConcatDataset = ConcatDataset( # datasets=DatasetLoader().load_datasets( # experiment_settings.train_dataset_settings, @@ -173,19 +172,18 @@ def run(self, experiment_settings: ExperimentSettingsT) -> None: # strategy=DatasetStrategy.TRAIN, # ) # ) - - train_dataset = datasets=DatasetLoader().load_datasets( - experiment_settings.train_dataset_settings, - tokenizer=self.tokenizer, - strategy=DatasetStrategy.TRAIN, - )[0] - - val_dataset = datasets=DatasetLoader().load_datasets( - experiment_settings.val_dataset_settings, - tokenizer=self.tokenizer, - strategy=DatasetStrategy.TRAIN, - )[0] - + + train_dataset = datasets = DatasetLoader().load_datasets( + experiment_settings.train_dataset_settings, + tokenizer=self.tokenizer, + strategy=DatasetStrategy.TRAIN, + )[0] + + val_dataset = datasets = DatasetLoader().load_datasets( + experiment_settings.val_dataset_settings, + tokenizer=self.tokenizer, + strategy=DatasetStrategy.TRAIN, + )[0] data_collator = self._get_data_collator(experiment_settings, self.tokenizer)