Commit 3c718c1
Fixed device placement error: move modality tensors to accelerator.device before encoding
lmeribal committed Sep 10, 2024
1 parent dd02fe6 commit 3c718c1
Showing 1 changed file with 7 additions and 7 deletions.
turbo_alignment/pipelines/preprocessing/multimodal.py (14 changes: 7 additions & 7 deletions)
@@ -44,7 +44,7 @@ def _load_modality_reader_encoder(
         ).to(device)
         return (reader, encoder)

-    def _read_modality_objects(self, reader, encoder, experiment_settings):
+    def _read_modality_objects(self, reader, encoder, experiment_settings, accelerator):
         modality_tensors = []

         available_extensions = ('jpg', 'jpeg', 'png', 'svg')
@@ -54,34 +54,34 @@ def _read_modality_objects(self, reader, encoder, experiment_settings):
         for extension in available_extensions:
             files_paths.extend(experiment_settings.dataset_path.glob(f'*.{extension}'))

-        safetensors_full_dict = self._async_process_files(reader, encoder, files_paths, experiment_settings)
+        safetensors_full_dict = self._async_process_files(reader, encoder, files_paths, experiment_settings, accelerator)
         return safetensors_full_dict

     @staticmethod
     def _build_encoder_config(encoder) -> dict:
         return {'emb_dim': encoder.emb_dim}

-    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx):
+    def _process_function(self, reader, encoder, batch_file_paths, experiment_settings, batch_idx, accelerator):
         logger.info(f'📖 Processing batch {batch_idx}')
         modality_objects = []
         for file_path in batch_file_paths:
             modality_objects.append(reader.read(str(file_path)))
         modality_objects = torch.cat(modality_objects)
-        encoded_modality_objects = encoder.encode(modality_objects).detach().cpu()
+        encoded_modality_objects = encoder.encode(modality_objects.to(accelerator.device)).detach().cpu()
         safetensors_dict_batch = self._get_safetensor_dict(encoded_modality_objects, batch_file_paths)
         logger.info(f'📖 Finish with batch {batch_idx}')

         return safetensors_dict_batch

-    def _async_process_files(self, reader, encoder, files_paths, experiment_settings):
+    def _async_process_files(self, reader, encoder, files_paths, experiment_settings, accelerator):
         modality_tensors = [None] * len(files_paths)

         batches = np.array_split(files_paths, len(files_paths) // experiment_settings.batch_size)
         safetensors_full_dict = {}

         with ThreadPoolExecutor() as executor:
             futures = {
-                executor.submit(self._process_function, reader, encoder, batch, experiment_settings, i): i
+                executor.submit(self._process_function, reader, encoder, batch, experiment_settings, i, accelerator): i
                 for i, batch in enumerate(batches)
             }
             for i, future in enumerate(as_completed(futures)):
@@ -108,7 +108,7 @@ def run(self, experiment_settings: MultimodalDatasetProcessingSettings) -> None:
             experiment_settings.modality,
             accelerator,
         )
-        safetensors_full_dict = self._read_modality_objects(reader, encoder, experiment_settings)
+        safetensors_full_dict = self._read_modality_objects(reader, encoder, experiment_settings, accelerator)

         logger.info(f'👩 Saving safetensors file...')
         save_file(
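The change threads the accelerator through the preprocessing helpers so that each batch of modality tensors is moved to accelerator.device before encoding, while the resulting embeddings are detached back to CPU for safetensors serialization. The snippet below is a minimal, self-contained sketch of that device-placement pattern; DummyEncoder, encode_batch, and the random input are illustrative assumptions, not turbo_alignment APIs.

import torch
from accelerate import Accelerator


class DummyEncoder(torch.nn.Module):
    """Illustrative encoder standing in for the pipeline's modality encoder (assumption)."""

    def __init__(self, emb_dim: int = 16):
        super().__init__()
        self.emb_dim = emb_dim
        self.proj = torch.nn.Linear(8, emb_dim)

    def encode(self, batch: torch.Tensor) -> torch.Tensor:
        return self.proj(batch)


def encode_batch(encoder: torch.nn.Module, batch: torch.Tensor, accelerator: Accelerator) -> torch.Tensor:
    # Readers produce CPU tensors; move the stacked batch onto the encoder's device
    # before encoding, then detach and return the embeddings on CPU for saving.
    return encoder.encode(batch.to(accelerator.device)).detach().cpu()


if __name__ == '__main__':
    accelerator = Accelerator()
    encoder = DummyEncoder().to(accelerator.device)            # encoder lives on the accelerator device
    batch = torch.cat([torch.randn(1, 8) for _ in range(4)])   # stacked per-file tensors, still on CPU
    embeddings = encode_batch(encoder, batch, accelerator)
    print(embeddings.shape, embeddings.device)                 # torch.Size([4, 16]) cpu

Keeping the .detach().cpu() step after encoding matters here because the embeddings are handed to a thread pool and written out with safetensors, which expects host-resident tensors.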