attention pooling with top k selection
lmeribal committed Sep 27, 2024
1 parent cc3a444 commit 18b4ca8
Showing 3 changed files with 19 additions and 1 deletion.
2 changes: 1 addition & 1 deletion turbo_alignment/modeling/multimodal/projectors/__init__.py
@@ -1,5 +1,5 @@
from turbo_alignment.modeling.multimodal.projectors.attention_pooling import (
-    AttentionPoolingMultiModalProjector,
+    AttentionPoolingMultiModalProjector, TopKAttentionPoolingMultiModalProjector
)
from turbo_alignment.modeling.multimodal.projectors.c_abstractor import CAbstractor
from turbo_alignment.modeling.multimodal.projectors.llava import (
17 changes: 17 additions & 0 deletions turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
@@ -21,3 +21,20 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        attention_scores = torch.softmax(self.attention_scores(projected_features), 1)
        pooled_output = torch.sum(projected_features * attention_scores, dim=1)
        return pooled_output

@MultiModalProjectorRegistry.register(ModalityProjectorType.TOP_K_ATTENTION_POOLING)
class TopKAttentionPoolingMultiModalProjector(torch.nn.Module):
    def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
        super().__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.text_hidden_size = text_hidden_size
        self.n_modality_embs = n_modality_embs  # number of patch embeddings kept per image (the k in top-k)
        self.linear_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
        self.attention_scores = torch.nn.Linear(text_hidden_size, 1)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        projected_features = self.linear_projection(image_features)  # map each image patch to the language model dimension
        attention_scores = torch.softmax(self.attention_scores(projected_features), 1)  # compute a learned attention score for each patch
        top_indices = torch.topk(attention_scores.squeeze(-1), k=self.n_modality_embs, dim=1).indices  # indices of the top-k patches by attention score
        top_k_hidden_states = torch.gather(projected_features, index=top_indices.unsqueeze(-1).expand(-1, -1, projected_features.size(-1)), dim=1)  # gather the top-k projected patches
        return top_k_hidden_states
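
A minimal usage sketch of the new projector (not part of the commit): the encoder width, language-model width, patch count, and value of n_modality_embs below are illustrative assumptions; only the class name and its constructor arguments come from the diff above.

import torch

from turbo_alignment.modeling.multimodal.projectors import (
    TopKAttentionPoolingMultiModalProjector,
)

# Assumed sizes: 1024-dim vision-encoder patches, a 4096-dim language model,
# and 8 retained patch embeddings per image.
projector = TopKAttentionPoolingMultiModalProjector(
    encoder_hidden_size=1024,
    text_hidden_size=4096,
    n_modality_embs=8,
)

image_features = torch.randn(2, 256, 1024)  # (batch, num_patches, encoder_hidden_size)
top_k_embs = projector(image_features)      # (batch, n_modality_embs, text_hidden_size)
print(top_k_embs.shape)                     # torch.Size([2, 8, 4096])

Unlike AttentionPoolingMultiModalProjector, which collapses all patches into a single weighted-sum vector, the top-k variant keeps the n_modality_embs highest-scoring projected patch embeddings.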
1 change: 1 addition & 0 deletions turbo_alignment/settings/modality.py
@@ -39,3 +39,4 @@ class ModalityProjectorType(str, Enum):
    LLAVA = 'llava'
    C_ABSTRACTOR = 'c_abstractor'
    ATTENTION_POOLING = 'attention_pooling'
    TOP_K_ATTENTION_POOLING = 'top_k_attention_pooling'
