From d9058a36dcbdfad69cd96f3a918a30fc0990306f Mon Sep 17 00:00:00 2001
From: lmeribal
Date: Tue, 10 Sep 2024 06:30:08 +0000
Subject: [PATCH] Batched processing without .cat

---
 .../multimodal/encoders/image/siglip.py   | 51 +++++++++++++++++++
 .../pipelines/preprocessing/multimodal.py |  8 ++-
 2 files changed, 54 insertions(+), 5 deletions(-)
 create mode 100644 turbo_alignment/modeling/multimodal/encoders/image/siglip.py

diff --git a/turbo_alignment/modeling/multimodal/encoders/image/siglip.py b/turbo_alignment/modeling/multimodal/encoders/image/siglip.py
new file mode 100644
index 0000000..d803c17
--- /dev/null
+++ b/turbo_alignment/modeling/multimodal/encoders/image/siglip.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+from typing import Optional
+
+import torch
+from transformers import CLIPModel
+
+from turbo_alignment.modeling.multimodal.encoders.image.base import BaseImageEncoder
+from turbo_alignment.modeling.multimodal.encoders.registry import (
+    ModalityEncoderRegistry,
+)
+from turbo_alignment.settings.modality import ModalityEncoderType
+
+
+@ModalityEncoderRegistry.register(ModalityEncoderType.CLIP)
+class CLIPImageModeling(BaseImageEncoder):
+    def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, is_pickle: bool = False):
+        super().__init__()
+        if model_clip is not None:
+            self.model_clip = model_clip
+        else:
+            self.model_clip = CLIPModel.from_pretrained(encoder_path)
+        self.is_pickle = is_pickle
+
+    @staticmethod
+    def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor:
+        if is_pickle:
+            return inputs
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
+        # -2 is default value of vision_feature_layer in llava config
+        # [1:] is everything after vit [cls] token
+        return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
+            :, 1:
+        ]  # FIXME: squeeze dimension?
+
+    def encode(self, inputs: torch.Tensor) -> torch.Tensor:
+        return self._get_clip_hidden_states(self.model_clip, inputs, self.is_pickle)
+
+    @property
+    def emb_dim(self):
+        return self.model_clip.config.vision_config.hidden_size
+
+    @property
+    def device(self):
+        return self.model_clip.device
+
+    @property
+    def n_modality_embs(self) -> int:
+        image_size = self.model_clip.config.vision_config.image_size
+        dummy_pixel_values = torch.empty(1, 3, image_size, image_size)
+        hidden_states = self._get_clip_hidden_states(self.model_clip, dummy_pixel_values, is_pickle=False)
+        return hidden_states.shape[1]
diff --git a/turbo_alignment/pipelines/preprocessing/multimodal.py b/turbo_alignment/pipelines/preprocessing/multimodal.py
index f85a65d..e6ecc16 100644
--- a/turbo_alignment/pipelines/preprocessing/multimodal.py
+++ b/turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -52,8 +52,6 @@ def _read_modality_objects(self, reader: BaseModalityReader, dataset_path: Path)
             files_paths.extend(dataset_path.glob(f'*.{extension}'))
 
         modality_tensors = self._async_process_files(reader, files_paths, total_number_of_objects)
-
-        modality_tensors = torch.cat(modality_tensors)
         return modality_tensors, files_paths
 
     @staticmethod
@@ -63,11 +61,11 @@ def _encode_modality_objects(
         encoded_modality_tensors = []
 
         logger.info('👩‍💻 Encoding objects...')
-        batched_modality_tensors = modality_tensors.split(batch_size)
-        for i, batch in enumerate(batched_modality_tensors):
-            logger.info(f'👩‍💻 Encoded {i} / {len(batched_modality_tensors)} batches')
+        for i in range(0, len(modality_tensors), batch_size):
+            batch = torch.cat(modality_tensors[i:i + batch_size], dim=0)
             encoded_modality_tensor_batch = encoder.encode(batch.to(encoder.device)).detach().cpu()
             encoded_modality_tensors.append(encoded_modality_tensor_batch)
+            logger.info(f'👩‍💻 Encoded {i} / {len(modality_tensors)} tensors')
 
         encoded_modality_tensors = torch.cat(encoded_modality_tensors)
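
Reviewer note: the sketch below restates the batching pattern this patch introduces in _encode_modality_objects as a standalone helper, only to make the intent of the change easier to review. The encode_in_batches name is hypothetical and not part of the patch; it assumes an encoder object exposing .encode() and .device, as BaseImageEncoder subclasses such as CLIPImageModeling do.

import torch


def encode_in_batches(modality_tensors: list[torch.Tensor], encoder, batch_size: int) -> torch.Tensor:
    # Hypothetical helper mirroring the patched loop; not part of the diff above.
    encoded_batches = []
    for i in range(0, len(modality_tensors), batch_size):
        # Concatenate only the current slice, so the raw dataset is never materialised as one tensor.
        batch = torch.cat(modality_tensors[i:i + batch_size], dim=0)
        # Encode one batch on the encoder's device, then move the result back to CPU.
        encoded_batches.append(encoder.encode(batch.to(encoder.device)).detach().cpu())
    return torch.cat(encoded_batches)

Compared with the removed modality_tensors = torch.cat(modality_tensors) in _read_modality_objects, only one raw batch at a time is concatenated and moved to the encoder's device, which is the point of "Batched processing without .cat".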