diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index be24488..7ca70fc 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -6,6 +6,7 @@
 import numpy as np
 import torch
 from accelerate import Accelerator
+from accelerate.utils import gather_object
 from allenai_common import Params
 from safetensors.torch import save_file
 from tqdm import tqdm
@@ -34,7 +35,8 @@ def __init__(self, *args, **kwargs):
         self.accelerator = Accelerator()
 
     def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
-        logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
+        if self.accelerator.is_main_process:
+            logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
 
         reader, encoder = self._load_modality_reader_encoder(
             experiment_settings.reader_settings,
@@ -43,7 +45,8 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
         )
         self._read_modality_objects(reader, encoder, experiment_settings)
 
-        logger.info(f'👩 Saved!')
+        if self.accelerator.is_main_process:
+            logger.info(f'👩 Saved!')
 
     def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
         modality_objects = []
@@ -60,24 +63,32 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings
 
         self.accelerator.wait_for_everyone()
 
-        with self.accelerator.split_between_processes(batches_all) as batches:
-            for i, batch in enumerate(tqdm(batches)):
+        for i, batches in enumerate(tqdm(batches_all)):
+            logger.info(f'📖 Processing batch {i} / {len(batches_all)}')
+            with self.accelerator.split_between_processes(batches) as batch:
                 try:
-                    logger.info(f'📖 Processing batch {i} / {len(batches)}')
                     batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
-                    for filename, encoded_output in batch_output.items():
-                        filepath = experiment_settings.output_file_path / (
-                            filename
-                            + '.'
-                            + experiment_settings.modality.value
-                            + '.'
-                            + experiment_settings.encoder_settings.modality_encoder_type
-                            + '.pt'
-                        )
-                        with open(filepath, 'wb') as f:
-                            torch.save(encoded_output, f)
                 except Exception as exc:
                     logger.error(f'Error reading file: {exc}')
 
+            self.accelerator.wait_for_everyone()
+
+            all_filenames = gather_object([filename for filename in batch_output.keys()])
+            all_tensors = gather_object([tensor for tensor in batch_output.values()])
+
+            if self.accelerator.is_main_process:
+                batch_output_from_all_processes = {f: t for f, t in zip(all_filenames, all_tensors)}
+                for filename, encoded_output in batch_output_from_all_processes.items():
+                    filepath = experiment_settings.output_file_path / (
+                        filename
+                        + '.'
+                        + experiment_settings.modality.value
+                        + '.'
+                        + experiment_settings.encoder_settings.modality_encoder_type
+                        + '.pt'
+                    )
+
+                    with open(filepath, 'wb') as f:
+                        torch.save(encoded_output, f)
 
     def _load_modality_reader_encoder(
         self,
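
The last hunk combines `split_between_processes`, `gather_object`, and a main-process-only write per batch. Below is a minimal, self-contained sketch of that pattern, not the repository's implementation: `encode_and_save`, `file_paths`, `output_dir`, `batch_size`, and the dummy encoding step are all hypothetical names. The sketch also pre-initializes `batch_output` before the `try` block, so a batch that fails on one rank still contributes an empty dict to `gather_object` instead of leaving the name unbound.

```python
# Minimal sketch (assumed names, simplified error handling) of the
# per-batch split -> gather -> main-process-save pattern from the hunk above.
from pathlib import Path

import torch
from accelerate import Accelerator
from accelerate.utils import gather_object


def encode_and_save(file_paths: list[str], output_dir: Path, batch_size: int = 8) -> None:
    accelerator = Accelerator()
    batches = [file_paths[i : i + batch_size] for i in range(0, len(file_paths), batch_size)]

    for i, batch in enumerate(batches):
        # Initialize before the try-block so a failed batch still yields an
        # (empty) dict for gather_object instead of an unbound name.
        batch_output: dict[str, torch.Tensor] = {}
        with accelerator.split_between_processes(batch) as shard:
            try:
                # Hypothetical per-process "encoding" step (a dummy tensor per file).
                batch_output = {path: torch.as_tensor([hash(path) % 97]) for path in shard}
            except Exception as exc:
                print(f'Error processing batch {i}: {exc}')

        # Make sure every rank has finished its shard before gathering.
        accelerator.wait_for_everyone()
        all_filenames = gather_object(list(batch_output.keys()))
        all_tensors = gather_object(list(batch_output.values()))

        # Only rank 0 writes to disk, so each file is saved exactly once.
        if accelerator.is_main_process:
            for filename, tensor in zip(all_filenames, all_tensors):
                torch.save(tensor, output_dir / f'{Path(filename).name}.pt')
```

Gathering per batch (inside the loop) rather than once at the end keeps peak memory bounded by a single batch and matches the structure of the new loop in the diff.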