From fa400617d0213d239ae07a75b17b105d372ee10f Mon Sep 17 00:00:00 2001
From: lmeribal
Date: Wed, 11 Sep 2024 15:52:04 +0300
Subject: [PATCH] debugging processor

---
 .../pipelines/preprocessing/multimodal.py | 219 ++++++++++--------
 1 file changed, 122 insertions(+), 97 deletions(-)

diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index 09460d6..ff8fbde 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 from accelerate import Accelerator
+import os
 from accelerate.utils import gather_object
 from allenai_common import Params
 from safetensors.torch import save_file
@@ -18,108 +19,132 @@
 from turbo_alignment.modeling.multimodal.encoders.base import BaseModalityEncoder
 from turbo_alignment.pipelines.base import BaseStrategy
 from turbo_alignment.settings.datasets.multimodal import (
-    MultimodalDatasetProcessingSettings,
+    MultimodalDatasetProcessingSettings,
 )
 from turbo_alignment.settings.modality import (
-    Modality,
-    ModalityEncoderSettings,
-    ModalityReaderSettings,
+    Modality,
+    ModalityEncoderSettings,
+    ModalityReaderSettings,
 )
 
 logger = get_project_logger()
 
 
 class PreprocessMultimodalDatasetStrategy(BaseStrategy):
-    def __init__(self, *args, **kwargs):
-        self.accelerator = Accelerator()
-
-    def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
-        if self.accelerator.is_main_process:
-            logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
-
-        reader, encoder = self._load_modality_reader_encoder(
-            experiment_settings.reader_settings,
-            experiment_settings.encoder_settings,
-            experiment_settings.modality,
-        )
-        self._read_modality_objects(reader, encoder, experiment_settings)
-
-        if self.accelerator.is_main_process:
-            logger.info(f'👩 Saved!')
-
-    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
-        modality_objects = []
-        for file_path in batch_file_paths:
-            modality_objects.append(reader.read(str(file_path)))
-        modality_objects = torch.cat(modality_objects)
-        encoded_modality_objects = encoder.encode(modality_objects.to(self.accelerator.device)).detach().cpu()
-        safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)
-
-        return safetensors_dict_batch
-
-    def _async_process_files(self, reader, encoder, files_paths, experiment_settings):
-        batches_all = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)
-
-        self.accelerator.wait_for_everyone()
-
-        for i, batches in enumerate(tqdm(batches_all)):
-            if self.accelerator.is_main_process:
-                logger.info(f'📖 Processing batch {i} / {len(batches_all)}')
-            with self.accelerator.split_between_processes(batches) as batch:
-                try:
-                    batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
-                except Exception as exc:
-                    logger.error(f'Error reading file: {exc}')
-            self.accelerator.wait_for_everyone()
-
-            all_filenames = gather_object([filename for filename in batch_output.keys()])
-            all_tensors = gather_object([tensor for tensor in batch_output.values()])
-
-            if self.accelerator.is_main_process:
-                batch_output_from_all_processes = {f: t for f, t in zip(all_filenames, all_tensors)}
-                for filename, encoded_output in batch_output_from_all_processes.items():
-                    filepath = experiment_settings.output_file_path / (
-                        filename
-                        + '.'
-                        + experiment_settings.modality.value
-                        + '.'
-                        + experiment_settings.encoder_settings.modality_encoder_type
-                        + '.pt'
-                    )
-
-                    with open(filepath, 'wb') as f:
-                        torch.save(encoded_output, f)
-
-    def _load_modality_reader_encoder(
-        self,
-        reader_settings: ModalityReaderSettings,
-        encoder_settings: ModalityEncoderSettings,
-        modality: Modality,
-    ) -> Tuple[BaseModalityReader, BaseModalityEncoder]:
-        device = self.accelerator.device
-        reader = ModalityReaderRegistry.by_name(modality).from_params(
-            Params({'type': reader_settings.reader_type, 'reader_path': reader_settings.reader_path})
-        )
-        encoder = ModalityEncoderRegistry.by_name(encoder_settings.modality_encoder_type)(
-            encoder_path=encoder_settings.encoder_path
-        ).to(device)
-        return (reader, encoder)
-
-    def _read_modality_objects(self, reader, encoder, experiment_settings):
-        modality_tensors = []
-
-        available_extensions = ('jpg', 'jpeg', 'png', 'svg')
-
-        logger.info('📖 Reading modality objects...')
-        files_paths: list[Path] = []
-        for extension in available_extensions:
-            files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))
-
-        self._async_process_files(reader, encoder, files_paths, experiment_settings)
-
-    @staticmethod
-    def _get_safetensor_dict(encoded_modality_tensors, encoded_file_paths):
-        tensors = {}
-        for file, tensor in zip(encoded_file_paths, encoded_modality_tensors):
-            tensors[file.name] = tensor.detach()
-        return tensors
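Note on the implementation removed above: `batch_output` is assigned inside the `try` within the `with` block but referenced after it, so a failing `_process_function` left it undefined (or stale) on that rank, and the per-batch `gather_object` calls could then crash or hang the whole group. That appears to be what this debugging commit sidesteps by dropping the per-batch gather. For reference, a toy sketch of the gather pattern the old code relied on, with hypothetical filenames and values (not part of the patch):

    # Minimal sketch: every rank contributes its local results and the main
    # process reassembles the full mapping before writing to disk.
    from accelerate import Accelerator
    from accelerate.utils import gather_object

    accelerator = Accelerator()
    # Hypothetical per-rank output, standing in for _process_function's dict.
    local_output = {f'img_{accelerator.process_index}.jpg': accelerator.process_index}

    all_filenames = gather_object(list(local_output.keys()))
    all_tensors = gather_object(list(local_output.values()))
    if accelerator.is_main_process:
        merged = dict(zip(all_filenames, all_tensors))  # one dict across all ranks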
+    def __init__(self, *args, **kwargs):
+        self.accelerator = Accelerator()
+
+    def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
+        if self.accelerator.is_main_process:
+            logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
+
+        reader, encoder = self._load_modality_reader_encoder(
+            experiment_settings.reader_settings,
+            experiment_settings.encoder_settings,
+            experiment_settings.modality,
+        )
+        self._read_modality_objects(reader, encoder, experiment_settings)
+
+        if self.accelerator.is_main_process:
+            logger.info(f'👩 Saved!')
+
+    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
+        modality_objects = []
+        for file_path in batch_file_paths:
+            modality_objects.append(reader.read(str(file_path)))
+        modality_objects = torch.cat(modality_objects)
+        encoded_modality_objects = encoder.encode(modality_objects.to(self.accelerator.device)).detach().cpu()
+        logger.info(f'encoded a batch: {encoded_modality_objects.shape}')
+        safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)
+
+        return safetensors_dict_batch
+
+
+    def _process_files(self, reader, encoder, files_paths, experiment_settings):
+        logger.info('👩 Processing without accelerate!')
+        batches_all = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)
+
+        for i, batches in enumerate(tqdm(batches_all)):
+            logger.info(f'📖 Processing batch {i} / {len(batches_all)}')
+            try:
+                batch_output = self._process_function(reader, encoder, batches, experiment_settings, i)
+                for filename, encoded_output in batch_output.items():
+                    filepath = experiment_settings.output_file_path / (
+                        filename
+                        + '.'
+                        + experiment_settings.modality.value
+                        + '.'
+                        + experiment_settings.encoder_settings.modality_encoder_type
+                        + '.pt'
+                    )
+                    logger.info(f'📖 Processing file {filepath}, {encoded_output.shape}')
+                    torch.save(encoded_output, filepath)
+                    logger.info('📖 Saved')
+            except Exception as exc:
+                logger.error(f'Error processing batch: {exc}')
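Note on the batching used by both `_process_files` and `_async_process_files`: `np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)` requests zero sections whenever there are fewer files than `batch_size`, and `np.array_split` raises `ValueError` for zero sections. A minimal guard, sketched with assumed names (not part of the patch):

    import numpy as np

    def make_batches(files_paths: list, batch_size: int) -> list:
        # np.array_split requires at least one section; clamp so small
        # datasets still yield a single batch instead of raising.
        n_sections = max(1, len(files_paths) // batch_size)
        return np.array_split(files_paths, n_sections)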
+
+
+    def _async_process_files(self, reader, encoder, files_paths, experiment_settings):
+        logger.info('👩 Processing with accelerate!')
+        batches_all = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)
+
+        self.accelerator.wait_for_everyone()
+
+        with self.accelerator.split_between_processes(batches_all) as batches:
+            for i, batch in enumerate(tqdm(batches)):
+                try:
+                    logger.info(f'📖 Processing batch {i} / {len(batches)}')
+                    batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
+                    for filename, encoded_output in batch_output.items():
+                        torch.save(
+                            encoded_output,
+                            experiment_settings.output_file_path
+                            / (
+                                filename
+                                + '.'
+                                + experiment_settings.modality.value
+                                + '.'
+                                + experiment_settings.encoder_settings.modality_encoder_type
+                                + '.pt'
+                            ),
+                        )
+                except Exception as exc:
+                    logger.error(f'Error processing batch: {exc}')
+
+    def _load_modality_reader_encoder(
+        self,
+        reader_settings: ModalityReaderSettings,
+        encoder_settings: ModalityEncoderSettings,
+        modality: Modality,
+    ) -> Tuple[BaseModalityReader, BaseModalityEncoder]:
+        device = self.accelerator.device
+        reader = ModalityReaderRegistry.by_name(modality).from_params(
+            Params({'type': reader_settings.reader_type, 'reader_path': reader_settings.reader_path})
+        )
+        encoder = ModalityEncoderRegistry.by_name(encoder_settings.modality_encoder_type)(
+            encoder_path=encoder_settings.encoder_path
+        ).to(device)
+        return (reader, encoder)
+
+    def _read_modality_objects(self, reader, encoder, experiment_settings):
+        modality_tensors = []
+
+        available_extensions = ('jpg', 'jpeg', 'png', 'svg')
+
+        if self.accelerator.is_main_process:
+            logger.info('📖 Reading modality objects...')
+        files_paths: list[Path] = []
+        for extension in available_extensions:
+            files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))
+
+        if os.environ.get('ACCELERATE_ENABLED', 'false') == 'true':
+            self._async_process_files(reader, encoder, files_paths, experiment_settings)
+        else:
+            self._process_files(reader, encoder, files_paths, experiment_settings)
+
+    @staticmethod
+    def _get_safetensor_dict(encoded_modality_tensors, encoded_file_paths):
+        tensors = {}
+        for file, tensor in zip(encoded_file_paths, encoded_modality_tensors):
+            tensors[file.name] = tensor.detach()
+        return tensors
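Note on the new dispatch in `_read_modality_objects`: `os.environ.get('ACCELERATE_ENABLED', 'false')` is compared against the exact string 'true', so values like `1` or `True` silently fall back to the single-process path. A more forgiving parse, sketched as a hypothetical helper (not part of the patch):

    import os

    def accelerate_enabled() -> bool:
        # Accept common truthy spellings of the flag instead of only 'true'.
        return os.environ.get('ACCELERATE_ENABLED', 'false').strip().lower() in {'1', 'true', 'yes'}

The multi-process path presumably still needs to run under `accelerate launch` with `ACCELERATE_ENABLED=true` exported, since `Accelerator()` alone does not spawn workers.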