Skip to content

Commit

Permalink
accelerate parallel processing
Browse files Browse the repository at this point in the history
  • Loading branch information
lmeribal committed Sep 11, 2024
1 parent a2e21d6 commit b12af8c
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions turbo_alignment/pipelines/preprocessing/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from allenai_common import Params
from safetensors.torch import save_file
from tqdm import tqdm
Expand Down Expand Up @@ -34,7 +35,8 @@ def __init__(self, *args, **kwargs):
self.accelerator = Accelerator()

def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
if self.accelerator.is_main_process:
logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')

reader, encoder = self._load_modality_reader_encoder(
experiment_settings.reader_settings,
Expand All @@ -43,7 +45,8 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
)
self._read_modality_objects(reader, encoder, experiment_settings)

logger.info(f'👩 Saved!')
if self.accelerator.is_main_process:
logger.info(f'👩 Saved!')

def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
modality_objects = []
Expand All @@ -60,24 +63,32 @@ def _async_process_files(self, reader, encoder, files_paths, experiment_settings

self.accelerator.wait_for_everyone()

with self.accelerator.split_between_processes(batches_all) as batches:
for i, batch in enumerate(tqdm(batches)):
for i, batches in enumerate(tqdm(batches_all)):
logger.info(f'📖 Processing batch {i} / {len(batches_all)}')
with self.accelerator.split_between_processes(batches) as batch:
try:
logger.info(f'📖 Processing batch {i} / {len(batches)}')
batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
for filename, encoded_output in batch_output.items():
filepath = experiment_settings.output_file_path / (
filename
+ '.'
+ experiment_settings.modality.value
+ '.'
+ experiment_settings.encoder_settings.modality_encoder_type
+ '.pt'
)
with open(filepath, 'wb') as f:
torch.save(encoded_output, f)
except Exception as exc:
logger.error(f'Error reading file: {exc}')
self.accelerator.wait_for_everyone()

all_filenames = gather_object([filename for filename in batch_output.keys()])
all_tensors = gather_object([tensor for tensor in batch_output.values()])

if self.accelerator.is_main_process:
batch_output_from_all_processes = {f: t for f, t in zip(all_filenames, all_tensors)}
for filename, encoded_output in batch_output_from_all_processes.items():
filepath = experiment_settings.output_file_path / (
filename
+ '.'
+ experiment_settings.modality.value
+ '.'
+ experiment_settings.encoder_settings.modality_encoder_type
+ '.pt'
)

with open(filepath, 'wb') as f:
torch.save(encoded_output, f)

def _load_modality_reader_encoder(
self,
Expand Down

0 comments on commit b12af8c

Please sign in to comment.