From 1bbc6890955b8f051326e5068eb2354050e1b901 Mon Sep 17 00:00:00 2001
From: Elisei Rykov
Date: Tue, 1 Oct 2024 15:43:13 +0300
Subject: [PATCH] joined projection

---
 .../dataset/multimodal/multimodal.py          | 20 ++--
 .../modeling/multimodal/lm/projection.py      |  7 +-
 .../modeling/multimodal/projectors/llava.py   | 93 ++++++++++++-------
 3 files changed, 78 insertions(+), 42 deletions(-)

diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index 77e83af..56cfe5f 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -174,12 +174,16 @@ def _read_modalities(self, record):

         modality_messages: list[MultimodalMessage] = [m for m in record['messages'] if m.modality_object_path]

-        messages_to_delete = len(modality_messages) - modality_messages_after_truncation
+        # print("🙂"*10, modality_messages)

-        if self._truncate_top:
-            modality_messages = modality_messages[messages_to_delete:]
-        else:
-            modality_messages = modality_messages[:modality_messages_after_truncation]
+        # messages_to_delete = len(modality_messages) - modality_messages_after_truncation
+
+        # if self._truncate_top:
+        #     modality_messages = modality_messages[messages_to_delete:]
+        # else:
+        #     modality_messages = modality_messages[:modality_messages_after_truncation]
+
+        # print("🙂"*10, modality_messages, len(modality_messages))

         modality_encodings: list[tuple[Modality, torch.Tensor]] = []
         try:
@@ -194,9 +198,9 @@ def _read_modalities(self, record):

         # record['modality_inputs'] = modality_encodings

-        if len(modality_encodings) != modality_messages_after_truncation:
-            # print("😆"*10, "len(modality_encodings) != modality_messages_after_truncation")
-            return None
+        # if len(modality_encodings) != modality_messages_after_truncation:
+        #     print("😆"*10, f"len(modality_encodings) != modality_messages_after_truncation {modality_messages_after_truncation} {len(modality_encodings)}")
+        #     return None

         return modality_encodings

diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index 5dc962d..32f4304 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -107,6 +107,8 @@ def convert_inputs_to_embeds(
         replica_spans = self.get_replica_spans(sample_input_ids)
         filtered_replica_spans = self.filter_replica_spans_with_modality(replica_spans, modality_spans)

+        # print("😆"*10, f"{len(modality_spans)=} {len(replica_spans)=} {len(sample_modality_inputs)=}")
+
         assert len(modality_spans) == len(sample_modality_inputs)

         grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
@@ -121,7 +123,7 @@ def convert_inputs_to_embeds(
             ).to(self.language_model.device)

         # Encode modalities and insert into input embeds
-        for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():
+        for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():  # FIXME: works only with image modality
             modality_encoder_input_indexes, modality_encoder_inputs = zip(*modality_encoder_inputs_with_indices)

             if self.language_model.dtype == torch.float32:
@@ -134,10 +136,11 @@ def convert_inputs_to_embeds(
                 )

             modality_replica_lm_input_embeds = pad_sequence(
-                [sample_lm_input_embeds[span[0] : span[1]] for span in filtered_replica_spans],
+                [sample_lm_input_embeds[replica_span[0] : modality_span[0].start] for replica_span, modality_span in zip(filtered_replica_spans, modality_spans)],
                 padding_value=0,
                 batch_first=True,
             )
+
             modality_encoder_embeddings = self.modality_adapters[modality](
                 encoded_modality_object_batch, modality_replica_lm_input_embeds
             )
diff --git a/turbo_alignment/modeling/multimodal/projectors/llava.py b/turbo_alignment/modeling/multimodal/projectors/llava.py
index 9230c69..e2ceb6b 100644
--- a/turbo_alignment/modeling/multimodal/projectors/llava.py
+++ b/turbo_alignment/modeling/multimodal/projectors/llava.py
@@ -55,44 +55,44 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
 #         return weighted_image_features


-@MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
-class LlavaWithTextMultiModalProjector(torch.nn.Module):
-    def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
-        super().__init__()
-        self.encoder_hidden_size = encoder_hidden_size
-        self.text_hidden_size = text_hidden_size
-        self.k = n_modality_embs  # Number of top patches to select
-        self.encoder_to_text_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
-        self.text_to_text_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
-        self.cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)
-        self.output_layer = torch.nn.Linear(text_hidden_size, text_hidden_size)
+# @MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
+# class LlavaWithTextMultiModalProjector(torch.nn.Module):
+#     def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
+#         super().__init__()
+#         self.encoder_hidden_size = encoder_hidden_size
+#         self.text_hidden_size = text_hidden_size
+#         self.k = n_modality_embs  # Number of top patches to select
+#         self.encoder_to_text_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
+#         self.text_to_text_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
+#         self.cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)
+#         self.output_layer = torch.nn.Linear(text_hidden_size, text_hidden_size)

-    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
-        # Project the image features into the text hidden space
-        projected_image = self.encoder_to_text_projection(image_features)
-        projected_text = self.text_to_text_projection(text_features)
+#     def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
+#         # Project the image features into the text hidden space
+#         projected_image = self.encoder_to_text_projection(image_features)
+#         projected_text = self.text_to_text_projection(text_features)

-        # Permute dimensions for attention
-        permuted_projected_image = projected_image.permute(1, 0, 2)  # [image_patches, batch_size, hidden_dim]
-        permuted_projected_text = projected_text.permute(1, 0, 2)  # [textual_tokens, batch_size, hidden_dim]
+#         # Permute dimensions for attention
+#         permuted_projected_image = projected_image.permute(1, 0, 2)  # [image_patches, batch_size, hidden_dim]
+#         permuted_projected_text = projected_text.permute(1, 0, 2)  # [textual_tokens, batch_size, hidden_dim]

-        # Cross-attention: text tokens attend to image patches
-        _, attention_weights = self.cross_attention(
-            query=permuted_projected_text,  # Text queries attend to image patches
-            key=permuted_projected_image,
-            value=permuted_projected_image
-        )
+#         # Cross-attention: text tokens attend to image patches
+#         _, attention_weights = self.cross_attention(
+#             query=permuted_projected_text,  # Text queries attend to image patches
+#             key=permuted_projected_image,
+#             value=permuted_projected_image
+#         )

-        # Average attention weights over text tokens to get importance scores for image patches
-        avg_attention_weights = attention_weights.mean(dim=1)  # [batch_size, image_patches]
+#         # Average attention weights over text tokens to get importance scores for image patches
+#         avg_attention_weights = attention_weights.mean(dim=1)  # [batch_size, image_patches]

-        # Select top-k patches based on attention scores
-        _, topk_indices = torch.topk(avg_attention_weights, self.k, dim=1)  # [batch_size, k]
-        topk_image_patches = projected_image.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, projected_image.size(-1)))  # [batch_size, k, hidden_dim]
+#         # Select top-k patches based on attention scores
+#         _, topk_indices = torch.topk(avg_attention_weights, self.k, dim=1)  # [batch_size, k]
+#         topk_image_patches = projected_image.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, projected_image.size(-1)))  # [batch_size, k, hidden_dim]

-        # Map the top-k patches into the LM embedding space
-        topk_mapped_patches = self.output_layer(topk_image_patches)  # [batch_size, k, text_hidden_size]
-        return topk_mapped_patches  # Output: [batch_size, k, lm_dim]
+#         # Map the top-k patches into the LM embedding space
+#         topk_mapped_patches = self.output_layer(topk_image_patches)  # [batch_size, k, text_hidden_size]
+#         return topk_mapped_patches  # Output: [batch_size, k, lm_dim]


 # @MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
@@ -126,3 +126,32 @@ def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) ->
 #         mapped_attentioned_values = self.output_layer(attention_values)

 #         return mapped_attentioned_values
+
+
+@MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
+class LlavaWithTextMultiModalProjector(torch.nn.Module):
+    def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
+        super().__init__()
+        self.cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)
+        self.image_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
+        self.text_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
+        self.final_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
+
+    def forward(self, image_features, text_features):
+        projected_image = self.image_projection(image_features)
+        # projected_text = self.text_projection(text_features)
+
+        image_patches = projected_image.transpose(0, 1)
+        # text_tokens = projected_text.transpose(0, 1)
+        text_tokens = text_features.transpose(0, 1)
+
+        _, attention_weights = self.cross_attention(query=text_tokens,
+                                                    key=image_patches,
+                                                    value=image_patches)
+
+        patch_importance = attention_weights.mean(dim=1)
+
+        attended_image_features = projected_image * patch_importance.unsqueeze(-1)
+
+        # return self.final_projection(attended_image_features.sum(1).unsqueeze(1))
+        return attended_image_features.sum(1).unsqueeze(1)
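
For context, here is a minimal standalone sketch of the "joined projection" weighting that the re-registered LlavaWithTextMultiModalProjector in llava.py performs: image patches are projected into the text hidden space, the replica text tokens attend to them through torch.nn.MultiheadAttention, the attention weights are averaged over the text tokens into per-patch importance scores, and the importance-weighted patches are summed into a single joined embedding. All tensor sizes below are illustrative assumptions, not values taken from the repository configs.

# Standalone sketch of the joined-projection weighting (assumed sizes).
import torch

encoder_hidden_size = 32   # assumed vision-encoder hidden size
text_hidden_size = 64      # assumed LM hidden size; must be divisible by num_heads=8
batch_size, n_patches, n_text_tokens = 2, 9, 5

image_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)

image_features = torch.randn(batch_size, n_patches, encoder_hidden_size)
text_features = torch.randn(batch_size, n_text_tokens, text_hidden_size)

projected_image = image_projection(image_features)   # [batch, patches, text_hidden]
image_patches = projected_image.transpose(0, 1)      # [patches, batch, text_hidden]
text_tokens = text_features.transpose(0, 1)          # [tokens, batch, text_hidden]

# Without batch_first, MultiheadAttention returns weights of shape
# [batch, query_len, key_len] = [batch, text_tokens, patches], averaged over heads by default.
_, attention_weights = cross_attention(query=text_tokens, key=image_patches, value=image_patches)

patch_importance = attention_weights.mean(dim=1)              # [batch, patches]
attended = projected_image * patch_importance.unsqueeze(-1)   # [batch, patches, text_hidden]
joined = attended.sum(dim=1).unsqueeze(1)                     # [batch, 1, text_hidden]
print(joined.shape)  # torch.Size([2, 1, 64])

Note that with this design the projector emits one joined embedding per image ([batch, 1, text_hidden]) regardless of n_modality_embs, and text_projection / final_projection are constructed but left unused in the committed forward pass (their calls are commented out).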