From 18b4ca8cb2209389be0455b5a9a366f330bd909b Mon Sep 17 00:00:00 2001
From: lmeribal
Date: Fri, 27 Sep 2024 08:34:10 +0000
Subject: [PATCH] attention pooling with top k selection

---
 .../modeling/multimodal/projectors/__init__.py     |  2 +-
 .../multimodal/projectors/attention_pooling.py     | 17 +++++++++++++++++
 turbo_alignment/settings/modality.py               |  1 +
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/turbo_alignment/modeling/multimodal/projectors/__init__.py b/turbo_alignment/modeling/multimodal/projectors/__init__.py
index fc40d19..4410853 100644
--- a/turbo_alignment/modeling/multimodal/projectors/__init__.py
+++ b/turbo_alignment/modeling/multimodal/projectors/__init__.py
@@ -1,5 +1,5 @@
 from turbo_alignment.modeling.multimodal.projectors.attention_pooling import (
-    AttentionPoolingMultiModalProjector,
+    AttentionPoolingMultiModalProjector, TopKAttentionPoolingMultiModalProjector
 )
 from turbo_alignment.modeling.multimodal.projectors.c_abstractor import CAbstractor
 from turbo_alignment.modeling.multimodal.projectors.llava import (
diff --git a/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py b/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
index 03525d2..0ffa4e1 100644
--- a/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
+++ b/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
@@ -21,3 +21,20 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         attention_scores = torch.softmax(self.attention_scores(projected_features), 1)
         pooled_output = torch.sum(projected_features * attention_scores, dim=1)
         return pooled_output
+
+
+@MultiModalProjectorRegistry.register(ModalityProjectorType.TOP_K_ATTENTION_POOLING)
+class TopKAttentionPoolingMultiModalProjector(torch.nn.Module):
+    def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
+        super().__init__()
+        self.encoder_hidden_size = encoder_hidden_size
+        self.text_hidden_size = text_hidden_size
+        self.n_modality_embs = n_modality_embs
+        self.linear_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
+        self.attention_scores = torch.nn.Linear(text_hidden_size, 1)
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        projected_features = self.linear_projection(image_features)  # map each image patch to the language model dimension
+        attention_scores = torch.softmax(self.attention_scores(projected_features), 1)  # calculate learnable attention scores for each patch
+        top_indices = torch.topk(attention_scores.squeeze(-1), k=self.n_modality_embs, dim=1).indices  # select indices of the top N patches according to attention scores
+        top_k_hidden_states = torch.gather(projected_features, index=top_indices.unsqueeze(-1).expand(-1, -1, projected_features.size(-1)), dim=1)  # select top patches
+        return top_k_hidden_states
diff --git a/turbo_alignment/settings/modality.py b/turbo_alignment/settings/modality.py
index 3eaa380..5d53666 100755
--- a/turbo_alignment/settings/modality.py
+++ b/turbo_alignment/settings/modality.py
@@ -39,3 +39,4 @@ class ModalityProjectorType(str, Enum):
     LLAVA = 'llava'
     C_ABSTRACTOR = 'c_abstractor'
     ATTENTION_POOLING = 'attention_pooling'
+    TOP_K_ATTENTION_POOLING = 'top_k_attention_pooling'
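
Below is a minimal standalone sketch of the shape flow through the new TopKAttentionPoolingMultiModalProjector's forward pass, using plain torch modules that mirror the two layers added in the patch. The dimension values are hypothetical and are not taken from the repository's configs.

    import torch

    # Hypothetical sizes, for illustration only (not taken from the repo's configs).
    batch_size, num_patches = 2, 576
    encoder_hidden_size, text_hidden_size = 1024, 4096
    n_modality_embs = 64  # number of patches kept per image

    # Plain modules mirroring the projector's two layers.
    linear_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
    attention_scores = torch.nn.Linear(text_hidden_size, 1)

    image_features = torch.randn(batch_size, num_patches, encoder_hidden_size)

    # Same steps as TopKAttentionPoolingMultiModalProjector.forward:
    projected = linear_projection(image_features)                # (B, P, text_hidden)
    scores = torch.softmax(attention_scores(projected), dim=1)   # (B, P, 1), sums to 1 over patches
    top_indices = torch.topk(scores.squeeze(-1), k=n_modality_embs, dim=1).indices  # (B, K)
    top_k = torch.gather(
        projected,
        dim=1,
        index=top_indices.unsqueeze(-1).expand(-1, -1, projected.size(-1)),
    )                                                            # (B, K, text_hidden)

    assert top_k.shape == (batch_size, n_modality_embs, text_hidden_size)

One consequence of this construction: torch.topk orders the returned indices by score rather than by original patch position, so the K selected patches reach the language model out of their spatial order; whether that matters depends on how downstream components consume the embeddings.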