adapter with top patches selector
lmeribal committed Sep 27, 2024
1 parent 18b4ca8 commit b91da3e
Showing 4 changed files with 42 additions and 6 deletions.
2 changes: 1 addition & 1 deletion turbo_alignment/dataset/multimodal/multimodal.py
@@ -1,3 +1,4 @@
+import gc
 import math
 from abc import ABC
 from pathlib import Path
@@ -6,7 +7,6 @@
 import numpy as np
 import torch
 from allenai_common import Params
-import gc

 from turbo_alignment.common.data.io import read_jsonl
 from turbo_alignment.common.data.multimodal import BaseModalityReader
3 changes: 2 additions & 1 deletion turbo_alignment/modeling/multimodal/projectors/__init__.py
@@ -1,5 +1,6 @@
 from turbo_alignment.modeling.multimodal.projectors.attention_pooling import (
-    AttentionPoolingMultiModalProjector, TopKAttentionPoolingMultiModalProjector
+    AttentionPoolingMultiModalProjector,
+    TopKAttentionPoolingMultiModalProjector,
 )
 from turbo_alignment.modeling.multimodal.projectors.c_abstractor import CAbstractor
 from turbo_alignment.modeling.multimodal.projectors.llava import (
42 changes: 38 additions & 4 deletions turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
@@ -22,6 +22,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         pooled_output = torch.sum(projected_features * attention_scores, dim=1)
         return pooled_output

+
 @MultiModalProjectorRegistry.register(ModalityProjectorType.TOP_K_ATTENTION_POOLING)
 class TopKAttentionPoolingMultiModalProjector(torch.nn.Module):
     def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
@@ -33,8 +34,41 @@ def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
         self.attention_scores = torch.nn.Linear(text_hidden_size, 1)

     def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        projected_features = self.linear_projection(image_features)  # map each image patch to the language model dimension
-        attention_scores = torch.softmax(self.attention_scores(projected_features), 1)  # calculate learnable attention scores for each patch
-        top_indices = torch.topk(attention_scores.squeeze(-1), k=self.n_modality_embs, dim=1).indices  # select indices top N patches according to attention scores
-        top_k_hidden_states = torch.gather(projected_features, index=top_indices.unsqueeze(-1).expand(-1, -1, projected_features.size(-1)), dim=1)  # select top patches
+        projected_features = self.linear_projection(
+            image_features
+        )  # map each image patch to the language model dimension
+        attention_scores = torch.softmax(
+            self.attention_scores(projected_features), 1
+        )  # compute a learnable attention score for each patch
+        top_indices = torch.topk(
+            attention_scores.squeeze(-1), k=self.n_modality_embs, dim=1
+        ).indices  # indices of the top-N patches by attention score
+        top_k_hidden_states = torch.gather(
+            projected_features, index=top_indices.unsqueeze(-1).expand(-1, -1, projected_features.size(-1)), dim=1
+        )  # gather the hidden states of the selected patches
         return top_k_hidden_states
+
+
+@MultiModalProjectorRegistry.register(ModalityProjectorType.THRESHOLD_SELECTOR)
+class ThresholdSelectorMultiModalProjector(torch.nn.Module):
+    def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
+        super().__init__()
+        self.encoder_hidden_size = encoder_hidden_size
+        self.text_hidden_size = text_hidden_size
+        self.n_modality_embs = n_modality_embs
+        self.linear_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
+        self.selection_score = torch.nn.Linear(text_hidden_size, 1)
+        self.threshold = 0.5
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        projected_features = self.linear_projection(
+            image_features
+        )  # map each image patch to the language model dimension
+        selection_scores = torch.sigmoid(
+            self.selection_score(projected_features)
+        )  # compute a learnable selection score for each patch
+        selection_mask = selection_scores < self.threshold
+        projected_features[
+            :, selection_mask[0, :, 0]
+        ] = 0  # zero out patches whose selection score is below the threshold (PoC-only shortcut)
+        return projected_features
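
For readers skimming the diff, the following is a minimal standalone sketch of the top-k patch selection that TopKAttentionPoolingMultiModalProjector.forward performs, rewritten with plain torch so it runs outside the repository. The batch size, patch count, hidden sizes, value of k, and variable names are illustrative assumptions rather than values taken from the codebase.

# Standalone sketch of the top-k patch selection (all sizes are illustrative).
import torch

batch, n_patches, enc_dim, text_dim, k = 2, 16, 32, 64, 4

image_features = torch.randn(batch, n_patches, enc_dim)
linear_projection = torch.nn.Linear(enc_dim, text_dim)
score_head = torch.nn.Linear(text_dim, 1)

projected = linear_projection(image_features)                 # (batch, n_patches, text_dim)
scores = torch.softmax(score_head(projected), dim=1)          # (batch, n_patches, 1), scores over patches
top_idx = torch.topk(scores.squeeze(-1), k=k, dim=1).indices  # (batch, k) indices of highest-scoring patches
top_patches = torch.gather(
    projected, dim=1, index=top_idx.unsqueeze(-1).expand(-1, -1, projected.size(-1))
)  # (batch, k, text_dim)
print(top_patches.shape)  # torch.Size([2, 4, 64])

Note that the projector reduces the modality sequence from n_patches to n_modality_embs entries; the expand call only broadcasts the selected indices across the hidden dimension so that gather can pick whole patch vectors.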
1 change: 1 addition & 0 deletions turbo_alignment/settings/modality.py
@@ -40,3 +40,4 @@ class ModalityProjectorType(str, Enum):
     C_ABSTRACTOR = 'c_abstractor'
     ATTENTION_POOLING = 'attention_pooling'
     TOP_K_ATTENTION_POOLING = 'top_k_attention_pooling'
+    THRESHOLD_SELECTOR = 'threshold_selector'
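
The newly registered THRESHOLD_SELECTOR value corresponds to ThresholdSelectorMultiModalProjector above. Unlike the top-k projector, it keeps the patch count fixed and zeroes out patches whose sigmoid selection score falls below 0.5. Below is a standalone sketch of that masking step; the dimensions and names are illustrative, and, as in the diff, the mask of the first batch element is applied to the whole batch, which the original comment flags as a PoC shortcut.

# Standalone sketch of the threshold-based masking (all sizes are illustrative).
import torch

batch, n_patches, enc_dim, text_dim = 2, 16, 32, 64
threshold = 0.5

image_features = torch.randn(batch, n_patches, enc_dim)
linear_projection = torch.nn.Linear(enc_dim, text_dim)
score_head = torch.nn.Linear(text_dim, 1)

projected = linear_projection(image_features)  # (batch, n_patches, text_dim)
scores = torch.sigmoid(score_head(projected))  # (batch, n_patches, 1), per-patch keep score
mask = scores < threshold                      # True where a patch should be suppressed
projected[:, mask[0, :, 0]] = 0                # zero the low-score patches; sequence length stays n_patches
print(projected.shape)                         # torch.Size([2, 16, 64])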
