From 40241cc743c8bd98a76b32f925a175a37ac984fd Mon Sep 17 00:00:00 2001 From: lmeribal Date: Wed, 11 Sep 2024 08:02:03 +0000 Subject: [PATCH] accelerator in dataset preprocessing --- .../pipelines/preprocessing/multimodal.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py index 249a1ef..6b0ea5e 100644 --- a/turbo_alignment/pipelines/preprocessing/multimodal.py +++ b/turbo_alignment/pipelines/preprocessing/multimodal.py @@ -58,28 +58,28 @@ def _process_function(self, reader, encoder, batch_file_paths, experiment_settin def _async_process_files(self, reader, encoder, files_paths, experiment_settings): batches_all = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size) - # self.accelerator.wait_for_everyone() - - # with self.accelerator.split_between_processes(batches_all) as batches: - for i, batch in enumerate(tqdm(batches_all)): - try: - logger.info(f'📖 Processing batch {i} / {len(batches_all)}') - batch_output = self._process_function(reader, encoder, batch, experiment_settings, i) - for filename, encoded_output in batch_output.items(): - torch.save( - encoded_output, - experiment_settings.output_file_path - / ( - filename - + '.' - + experiment_settings.modality.value - + '.' - + experiment_settings.encoder_settings.modality_encoder_type - + '.pt' - ), - ) - except Exception as exc: - logger.error(f'Error reading file: {exc}') + self.accelerator.wait_for_everyone() + + with self.accelerator.split_between_processes(batches_all) as batches: + for i, batch in enumerate(tqdm(batches)): + try: + logger.info(f'📖 Processing batch {i} / {len(batches)}') + batch_output = self._process_function(reader, encoder, batch, experiment_settings, i) + for filename, encoded_output in batch_output.items(): + torch.save( + encoded_output, + experiment_settings.output_file_path + / ( + filename + + '.' + + experiment_settings.modality.value + + '.' + + experiment_settings.encoder_settings.modality_encoder_type + + '.pt' + ), + ) + except Exception as exc: + logger.error(f'Error reading file: {exc}') def _load_modality_reader_encoder( self,