debugging processor
lmeribal committed Sep 11, 2024
1 parent a690bc1 commit fa40061
Showing 1 changed file with 122 additions and 97 deletions: turbo_alignment/pipelines/preprocessing/multimodal.py
import os

import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import gather_object
from allenai_common import Params
from safetensors.torch import save_file
from turbo_alignment.modeling.multimodal.encoders.base import BaseModalityEncoder
from turbo_alignment.pipelines.base import BaseStrategy
from turbo_alignment.settings.datasets.multimodal import (
    MultimodalDatasetProcessingSettings,
)
from turbo_alignment.settings.modality import (
    Modality,
    ModalityEncoderSettings,
    ModalityReaderSettings,
)

logger = get_project_logger()


class PreprocessMultimodalDatasetStrategy(BaseStrategy):
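    """Preprocess a multimodal dataset ahead of training: read each modality file,
    run it through the modality encoder, and cache the resulting tensors to disk."""
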
    def __init__(self, *args, **kwargs):
        self.accelerator = Accelerator()

    def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
        if self.accelerator.is_main_process:
            logger.info(f'πŸ‘© Start dataset processing with the following settings:\n{experiment_settings}')

        reader, encoder = self._load_modality_reader_encoder(
            experiment_settings.reader_settings,
            experiment_settings.encoder_settings,
            experiment_settings.modality,
        )
        self._read_modality_objects(reader, encoder, experiment_settings)

        if self.accelerator.is_main_process:
            logger.info('πŸ‘© Saved!')

    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
        # Read every file in the batch, stack the tensors, and encode them in one forward pass
        modality_objects = []
        for file_path in batch_file_paths:
            modality_objects.append(reader.read(str(file_path)))
        modality_objects = torch.cat(modality_objects)
        encoded_modality_objects = encoder.encode(modality_objects.to(self.accelerator.device)).detach().cpu()
        logger.info(f'encoded a batch: {encoded_modality_objects.shape}')
        return self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)

    def _process_files(self, reader, encoder, files_paths, experiment_settings):
        logger.info('πŸ‘© Processing without accelerate!')
        # Guard against fewer files than batch_size: np.array_split needs at least one section
        num_batches = max(1, len(files_paths) // experiment_settings.batch_size)
        batches_all = np.array_split(files_paths, num_batches)

        for i, batch in enumerate(tqdm(batches_all)):
            logger.info(f'πŸ“– Processing batch {i + 1} / {len(batches_all)}')
            try:
                batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
                for filename, encoded_output in batch_output.items():
                    filepath = experiment_settings.output_file_path / (
                        filename
                        + '.'
                        + experiment_settings.modality.value
                        + '.'
                        + experiment_settings.encoder_settings.modality_encoder_type
                        + '.pt'
                    )
                    logger.info(f'πŸ“– Processing file {filepath}, {encoded_output.shape}')
                    torch.save(encoded_output, filepath)
                logger.info('πŸ“– Saved')
            except Exception as exc:
                logger.error(f'Error processing batch {i}: {exc}')

    def _async_process_files(self, reader, encoder, files_paths, experiment_settings):
        logger.info('πŸ‘© Processing with accelerate!')
        num_batches = max(1, len(files_paths) // experiment_settings.batch_size)
        batches_all = np.array_split(files_paths, num_batches)

        self.accelerator.wait_for_everyone()

        # Shard the batches across processes; each process encodes and saves its own shard
        with self.accelerator.split_between_processes(batches_all) as batches:
            for i, batch in enumerate(tqdm(batches)):
                try:
                    logger.info(f'πŸ“– Processing batch {i + 1} / {len(batches)}')
                    batch_output = self._process_function(reader, encoder, batch, experiment_settings, i)
                    for filename, encoded_output in batch_output.items():
                        torch.save(
                            encoded_output,
                            experiment_settings.output_file_path
                            / (
                                filename
                                + '.'
                                + experiment_settings.modality.value
                                + '.'
                                + experiment_settings.encoder_settings.modality_encoder_type
                                + '.pt'
                            ),
                        )
                except Exception as exc:
                    logger.error(f'Error processing batch {i}: {exc}')

    def _load_modality_reader_encoder(
        self,
        reader_settings: ModalityReaderSettings,
        encoder_settings: ModalityEncoderSettings,
        modality: Modality,
    ) -> Tuple[BaseModalityReader, BaseModalityEncoder]:
        device = self.accelerator.device
        reader = ModalityReaderRegistry.by_name(modality).from_params(
            Params({'type': reader_settings.reader_type, 'reader_path': reader_settings.reader_path})
        )
        encoder = ModalityEncoderRegistry.by_name(encoder_settings.modality_encoder_type)(
            encoder_path=encoder_settings.encoder_path
        ).to(device)
        return reader, encoder

    def _read_modality_objects(self, reader, encoder, experiment_settings):
        available_extensions = ('jpg', 'jpeg', 'png', 'svg')

        if self.accelerator.is_main_process:
            logger.info('πŸ“– Reading modality objects...')
        files_paths: list[Path] = []
        for extension in available_extensions:
            files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))

        # ACCELERATE_ENABLED toggles between the distributed and the single-process path
        if os.environ.get('ACCELERATE_ENABLED', 'false') == 'true':
            self._async_process_files(reader, encoder, files_paths, experiment_settings)
        else:
            self._process_files(reader, encoder, files_paths, experiment_settings)

    @staticmethod
    def _get_safetensor_dict(encoded_modality_tensors, encoded_file_paths):
        tensors = {}
        for file, tensor in zip(encoded_file_paths, encoded_modality_tensors):
            tensors[file.name] = tensor.detach()
        return tensors
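
An aside on the storage format: batches are persisted tensor-by-tensor with torch.save, even though save_file from safetensors.torch is imported above and now goes unused. If the safetensors format were preferred, the per-batch dict could be written in one call. A minimal sketch, assuming one output file per batch (batch_output and i as in _process_files; save_file requires contiguous tensors):

from safetensors.torch import save_file

# Hypothetical alternative, not part of this commit: one .safetensors file per batch,
# keyed by the source filename (tensors may need .contiguous() first).
save_file(batch_output, str(experiment_settings.output_file_path / f'batch_{i}.safetensors'))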

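For context, a minimal launch sketch (an assumption based on the strategy interface above; constructing experiment_settings is elided because it happens outside this file):

import os

# Route processing through _async_process_files; unset or 'false' falls back to _process_files
os.environ['ACCELERATE_ENABLED'] = 'true'

strategy = PreprocessMultimodalDatasetStrategy()
strategy.run(experiment_settings)  # experiment_settings: MultimodalDatasetProcessingSettings

With ACCELERATE_ENABLED set, the script would be started via `accelerate launch` so that split_between_processes can hand each process its own shard of batches.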