diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index 4d98caa..c9af204 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Tuple
 
+import h5py
 import numpy as np
 import torch
 from accelerate import Accelerator
@@ -29,6 +30,50 @@ class PreprocessMultimodalDatasetStrategy(BaseStrategy):
+    def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
+        logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
+        accelerator = Accelerator()
+
+        reader, encoder = self._load_modality_reader_encoder(
+            experiment_settings.reader_settings,
+            experiment_settings.encoder_settings,
+            experiment_settings.modality,
+            accelerator,
+        )
+        self._read_modality_objects(reader, encoder, experiment_settings, accelerator)
+
+        logger.info('👩 Saved!')
+
+    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx, accelerator):
+        modality_objects = []
+        for file_path in batch_file_paths:
+            modality_objects.append(reader.read(str(file_path)))
+        modality_objects = torch.cat(modality_objects)
+        encoded_modality_objects = encoder.encode(modality_objects.to(accelerator.device)).detach().cpu()
+        safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)
+
+        return safetensors_dict_batch
+
+    def _async_process_files(self, reader, encoder, files_paths, experiment_settings, accelerator):
+        output_file_name = experiment_settings.output_file_path / (
+            experiment_settings.modality.value
+            + '.'
+            + experiment_settings.encoder_settings.modality_encoder_type
+            + '.h5'
+        )
+
+        batches = np.array_split(files_paths, max(1, len(files_paths) // experiment_settings.batch_size))
+
+        for i, batch in enumerate(tqdm(batches)):
+            try:
+                logger.info(f'📖 Processing batch {i} / {len(batches)}')
+                batch_output = self._process_function(reader, encoder, batch, experiment_settings, i, accelerator)
+                with h5py.File(output_file_name, 'a') as f:
+                    for path, encoded_output in batch_output.items():
+                        f.create_dataset(path, data=encoded_output.numpy())
+            except Exception as exc:
+                logger.error(f'Error processing batch {i}: {exc}')
+
     @staticmethod
     def _load_modality_reader_encoder(
         reader_settings: ModalityReaderSettings,
@@ -55,69 +100,15 @@ def _read_modality_objects(self, reader, encoder, experiment_settings, accelerator):
         for extension in available_extensions:
             files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))
 
-        safetensors_full_dict = self._async_process_files(
-            reader, encoder, files_paths, experiment_settings, accelerator
-        )
-        return safetensors_full_dict
+        self._async_process_files(reader, encoder, files_paths, experiment_settings, accelerator)
 
     @staticmethod
     def _build_encoder_config(encoder) -> dict:
         return {'emb_dim': encoder.emb_dim}
 
-    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx, accelerator):
-        logger.info(f'📖 Processing batch {batch_idx}')
-        modality_objects = []
-        for file_path in batch_file_paths:
-            modality_objects.append(reader.read(str(file_path)))
-        modality_objects = torch.cat(modality_objects)
-        encoded_modality_objects = encoder.encode(modality_objects.to(accelerator.device)).detach().cpu()
-        safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)
-
-        return safetensors_dict_batch
-
-    def _async_process_files(self, reader, encoder, files_paths, experiment_settings, accelerator):
-        modality_tensors = [None] * len(files_paths)
-
-        print(experiment_settings.batch_size, len(files_paths))
-        batches = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)
-
-        safetensors_full_dict = {}
-        for i, batch in enumerate(tqdm(batches)):
-            try:
-                batch_output = self._process_function(reader, encoder, batch, experiment_settings, i, accelerator)
-                safetensors_full_dict.update(batch_output)
-            except Exception as exc:
-                logger.error(f'Error reading file: {exc}')
-        return safetensors_full_dict
-
     @staticmethod
     def _get_safetensor_dict(encoded_modality_tensors, encoded_file_paths):
         tensors = {}
         for file, tensor in zip(encoded_file_paths, encoded_modality_tensors):
             tensors[file.name] = tensor.detach()
         return tensors
-
-    def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
-        logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
-        accelerator = Accelerator()
-
-        reader, encoder = self._load_modality_reader_encoder(
-            experiment_settings.reader_settings,
-            experiment_settings.encoder_settings,
-            experiment_settings.modality,
-            accelerator,
-        )
-        safetensors_full_dict = self._read_modality_objects(reader, encoder, experiment_settings, accelerator)
-
-        logger.info(f'👩 Saving safetensors file...')
-        save_file(
-            safetensors_full_dict,
-            experiment_settings.output_file_path
-            / (
-                experiment_settings.modality.value
-                + '.'
-                + experiment_settings.encoder_settings.modality_encoder_type
-                + '.safetensors'
-            ),
-        )
-        logger.info(f'👩 Saved!')
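
For reference, a minimal sketch (not part of the patch) of how the per-batch HDF5 output written by `_async_process_files` could be read back downstream. The file name `image.clip.h5` is only an illustrative assumption following the `<modality>.<encoder_type>.h5` pattern built above; dataset keys are the source file names, as produced by `_get_safetensor_dict`.

```python
import h5py
import torch

# Assumed output path following the '<modality>.<encoder_type>.h5' naming used in
# _async_process_files; 'image' and 'clip' are placeholder values for illustration.
with h5py.File('image.clip.h5', 'r') as f:
    for name in list(f.keys())[:5]:  # datasets are keyed by the original file name
        embedding = torch.from_numpy(f[name][()])  # load the stored encoder output back as a tensor
        print(name, tuple(embedding.shape))
```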