Skip to content

Commit

Permalink
in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
lmeribal committed Sep 10, 2024
1 parent 2544a00 commit e13ddb5
Showing 1 changed file with 46 additions and 55 deletions.
101 changes: 46 additions & 55 deletions turbo_alignment/pipelines/preprocessing/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Tuple

import h5py
import numpy as np
import torch
from accelerate import Accelerator
Expand Down Expand Up @@ -29,6 +30,50 @@


class PreprocessMultimodalDatasetStrategy(BaseStrategy):
def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
accelerator = Accelerator()

reader, encoder = self._load_modality_reader_encoder(
experiment_settings.reader_settings,
experiment_settings.encoder_settings,
experiment_settings.modality,
accelerator,
)
self._read_modality_objects(reader, encoder, experiment_settings, accelerator)

logger.info(f'👩 Saved!')

def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx, accelerator):
modality_objects = []
for file_path in batch_file_paths:
modality_objects.append(reader.read(str(file_path)))
modality_objects = torch.cat(modality_objects)
encoded_modality_objects = encoder.encode(modality_objects.to(accelerator.device)).detach().cpu()
safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)

return safetensors_dict_batch

def _async_process_files(self, reader, encoder, files_paths, experiment_settings, accelerator):
output_file_name = experiment_settings.output_file_path / (
experiment_settings.modality.value
+ '.'
+ experiment_settings.encoder_settings.modality_encoder_type
+ '.h5'
)

batches = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)

for i, batch in enumerate(tqdm(batches)):
try:
logger.info(f'📖 Processing batch {i} / {len(batches)}')
batch_output = self._process_function(reader, encoder, batch, experiment_settings, i, accelerator)
with h5py.File(output_file_name, 'a') as f:
for path, encoded_output in batch_output.items():
f.create_dataset(path, data=encoded_output.numpy())
except Exception as exc:
logger.error(f'Error reading file: {exc}')

@staticmethod
def _load_modality_reader_encoder(
reader_settings: ModalityReaderSettings,
Expand All @@ -55,69 +100,15 @@ def _read_modality_objects(self, reader, encoder, experiment_settings, accelerat
for extension in available_extensions:
files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))

safetensors_full_dict = self._async_process_files(
reader, encoder, files_paths, experiment_settings, accelerator
)
return safetensors_full_dict
self._async_process_files(reader, encoder, files_paths, experiment_settings, accelerator)

@staticmethod
def _build_encoder_config(encoder) -> dict:
return {'emb_dim': encoder.emb_dim}

def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx, accelerator):
logger.info(f'📖 Processing batch {batch_idx}')
modality_objects = []
for file_path in batch_file_paths:
modality_objects.append(reader.read(str(file_path)))
modality_objects = torch.cat(modality_objects)
encoded_modality_objects = encoder.encode(modality_objects.to(accelerator.device)).detach().cpu()
safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)

return safetensors_dict_batch

def _async_process_files(self, reader, encoder, files_paths, experiment_settings, accelerator):
modality_tensors = [None] * len(files_paths)

print(experiment_settings.batch_size, len(files_paths))
batches = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)

safetensors_full_dict = {}
for i, batch in enumerate(tqdm(batches)):
try:
batch_output = self._process_function(reader, encoder, batch, experiment_settings, i, accelerator)
safetensors_full_dict.update(batch_output)
except Exception as exc:
logger.error(f'Error reading file: {exc}')
return safetensors_full_dict

@staticmethod
def _get_safetensor_dict(encoded_modality_tensors, encoded_file_paths):
tensors = {}
for file, tensor in zip(encoded_file_paths, encoded_modality_tensors):
tensors[file.name] = tensor.detach()
return tensors

def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
logger.info(f'👩 Start dataset processing with the following settings:\n{experiment_settings}')
accelerator = Accelerator()

reader, encoder = self._load_modality_reader_encoder(
experiment_settings.reader_settings,
experiment_settings.encoder_settings,
experiment_settings.modality,
accelerator,
)
safetensors_full_dict = self._read_modality_objects(reader, encoder, experiment_settings, accelerator)

logger.info(f'👩 Saving safetensors file...')
save_file(
safetensors_full_dict,
experiment_settings.output_file_path
/ (
experiment_settings.modality.value
+ '.'
+ experiment_settings.encoder_settings.modality_encoder_type
+ '.safetensors'
),
)
logger.info(f'👩 Saved!')

0 comments on commit e13ddb5

Please sign in to comment.