Batched processing without .cat
lmeribal committed Sep 10, 2024
1 parent 184e19c commit d9058a3
Showing 2 changed files with 54 additions and 5 deletions.
51 changes: 51 additions & 0 deletions turbo_alignment/modeling/multimodal/encoders/image/siglip.py
@@ -0,0 +1,51 @@
from pathlib import Path
from typing import Optional

import torch
from transformers import CLIPModel

from turbo_alignment.modeling.multimodal.encoders.image.base import BaseImageEncoder
from turbo_alignment.modeling.multimodal.encoders.registry import (
    ModalityEncoderRegistry,
)
from turbo_alignment.settings.modality import ModalityEncoderType


@ModalityEncoderRegistry.register(ModalityEncoderType.CLIP)
class CLIPImageModeling(BaseImageEncoder):
    def __init__(self, encoder_path: Path, model_clip: Optional[CLIPModel] = None, is_pickle: bool = False):
        super().__init__()
        if model_clip is not None:
            self.model_clip = model_clip
        else:
            self.model_clip = CLIPModel.from_pretrained(encoder_path)
        self.is_pickle = is_pickle

    @staticmethod
    def _get_clip_hidden_states(model_clip: CLIPModel, inputs: torch.Tensor, is_pickle: bool = False) -> torch.Tensor:
        if is_pickle:
            return inputs
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py#L213
        # -2 is the default value of vision_feature_layer in the LLaVA config
        # [1:] keeps everything after the ViT [CLS] token
        return model_clip.vision_model(inputs.squeeze(1), output_hidden_states=True).hidden_states[-2][
            :, 1:
        ]  # FIXME: squeeze dimension?

    def encode(self, inputs: torch.Tensor) -> torch.Tensor:
        return self._get_clip_hidden_states(self.model_clip, inputs, self.is_pickle)

    @property
    def emb_dim(self):
        return self.model_clip.config.vision_config.hidden_size

    @property
    def device(self):
        return self.model_clip.device

    @property
    def n_modality_embs(self) -> int:
        # Run a dummy forward pass to infer how many patch embeddings one image yields.
        image_size = self.model_clip.config.vision_config.image_size
        dummy_pixel_values = torch.empty(1, 3, image_size, image_size)
        hidden_states = self._get_clip_hidden_states(self.model_clip, dummy_pixel_values, is_pickle=False)
        return hidden_states.shape[1]
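
For context, a minimal usage sketch of the new encoder (not part of the commit). The checkpoint name openai/clip-vit-large-patch14, the CLIPImageProcessor preprocessing, and the random stand-in image are all assumptions for illustration:

from pathlib import Path

import numpy as np
import torch
from transformers import CLIPImageProcessor

from turbo_alignment.modeling.multimodal.encoders.image.siglip import CLIPImageModeling

encoder = CLIPImageModeling(encoder_path=Path('openai/clip-vit-large-patch14'))  # assumed checkpoint
processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14')

# A random HWC uint8 image stands in for real data.
image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
pixel_values = processor(images=image, return_tensors='pt').pixel_values  # (1, 3, 224, 224)

with torch.no_grad():
    embeddings = encoder.encode(pixel_values.to(encoder.device))

print(embeddings.shape)         # (1, 256, 1024) for ViT-L/14: 256 patch embeddings of size emb_dim
print(encoder.n_modality_embs)  # 256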
8 changes: 3 additions & 5 deletions turbo_alignment/pipelines/preprocessing/multimodal.py
@@ -52,8 +52,6 @@ def _read_modality_objects(self, reader: BaseModalityReader, dataset_path: Path)
         files_paths.extend(dataset_path.glob(f'*.{extension}'))

         modality_tensors = self._async_process_files(reader, files_paths, total_number_of_objects)
-
-        modality_tensors = torch.cat(modality_tensors)
         return modality_tensors, files_paths

@staticmethod
@@ -63,11 +61,11 @@ def _encode_modality_objects(
         encoded_modality_tensors = []

         logger.info('πŸ‘©β€πŸ’» Encoding objects...')
-        batched_modality_tensors = modality_tensors.split(batch_size)
-        for i, batch in enumerate(batched_modality_tensors):
-            logger.info(f'πŸ‘©β€πŸ’» Encoded {i} / {len(batched_modality_tensors)} batches')
+        for i in range(0, len(modality_tensors), batch_size):
+            batch = torch.cat(modality_tensors[i:i + batch_size], dim=0)
             encoded_modality_tensor_batch = encoder.encode(batch.to(encoder.device)).detach().cpu()
             encoded_modality_tensors.append(encoded_modality_tensor_batch)
+            logger.info(f'πŸ‘©β€πŸ’» Encoded {i} / {len(modality_tensors)} tensors')

         encoded_modality_tensors = torch.cat(encoded_modality_tensors)

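The net effect of the commit, as a standalone sketch (hypothetical shapes and a stub in place of encoder.encode(); not part of the commit): the reader now returns a list of per-file tensors, and only one batch is concatenated and moved to the device at a time, so peak memory is bounded by batch_size rather than by the whole dataset:

import torch

batch_size = 4

# What _read_modality_objects now returns: a list of per-file tensors
# (shapes are hypothetical), instead of one pre-concatenated tensor.
modality_tensors = [torch.randn(1, 3, 224, 224) for _ in range(10)]

encoded = []
for i in range(0, len(modality_tensors), batch_size):
    # Only batch_size tensors are concatenated at once, so peak memory
    # stays proportional to the batch, not to the whole dataset.
    batch = torch.cat(modality_tensors[i:i + batch_size], dim=0)
    encoded.append(batch.mean(dim=(1, 2, 3)))  # stub for encoder.encode(batch)

print(torch.cat(encoded).shape)  # torch.Size([10]) with the stub encoder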
