diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py index 6b0ea5e..158dd8d 100644 --- a/turbo_alignment/pipelines/preprocessing/multimodal.py +++ b/turbo_alignment/pipelines/preprocessing/multimodal.py @@ -64,7 +64,9 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings for i, batch in enumerate(tqdm(batches)): try: logger.info(f'📖 Processing batch {i} / {len(batches)}') + self.accelerator.wait_for_everyone() batch_output = self._process_function(reader, encoder, batch, experiment_settings, i) + self.accelerator.wait_for_everyone() for filename, encoded_output in batch_output.items(): torch.save( encoded_output, @@ -78,6 +80,7 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings + '.pt' ), ) + self.accelerator.wait_for_everyone() except Exception as exc: logger.error(f'Error reading file: {exc}')