Commit: joined projection
Elisei Rykov committed Oct 1, 2024
1 parent 5e395d1 commit 1bbc689
Showing 3 changed files with 78 additions and 42 deletions.
20 changes: 12 additions & 8 deletions turbo_alignment/dataset/multimodal/multimodal.py
@@ -174,12 +174,16 @@ def _read_modalities(self, record):

modality_messages: list[MultimodalMessage] = [m for m in record['messages'] if m.modality_object_path]

messages_to_delete = len(modality_messages) - modality_messages_after_truncation
# print("🙂"*10, modality_messages)

if self._truncate_top:
modality_messages = modality_messages[messages_to_delete:]
else:
modality_messages = modality_messages[:modality_messages_after_truncation]
# messages_to_delete = len(modality_messages) - modality_messages_after_truncation

# if self._truncate_top:
# modality_messages = modality_messages[messages_to_delete:]
# else:
# modality_messages = modality_messages[:modality_messages_after_truncation]

# print("🙂"*10, modality_messages, len(modality_messages))

modality_encodings: list[tuple[Modality, torch.Tensor]] = []
try:
@@ -194,9 +198,9 @@ def _read_modalities(self, record):

# record['modality_inputs'] = modality_encodings

if len(modality_encodings) != modality_messages_after_truncation:
# print("😆"*10, "len(modality_encodings) != modality_messages_after_truncation")
return None
# if len(modality_encodings) != modality_messages_after_truncation:
# print("😆"*10, f"len(modality_encodings) != modality_messages_after_truncation {modality_messages_after_truncation} {len(modality_encodings)}")
# return None

return modality_encodings

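For orientation, a minimal sketch of what _read_modalities effectively does after this commit: the truncation-based slicing of modality messages and the post-encoding length check are disabled, so every message carrying a modality_object_path is read and encoded. The reader callable and the exception types below are stand-ins, not the repository's helpers.

# Hedged sketch of the post-change control flow (stand-in names, not repo code)
def _read_modalities_sketch(record, read_modality_object) -> list | None:
    # Keep every message that carries a modality object; the old
    # messages_to_delete / _truncate_top slicing is no longer applied.
    modality_messages = [m for m in record['messages'] if m.modality_object_path]

    modality_encodings = []
    try:
        for message in modality_messages:
            # read_modality_object stands in for the dataset's modality reader
            modality_encodings.append(read_modality_object(message))
    except (OSError, ValueError):
        # unreadable modality objects still invalidate the whole record
        return None

    # the old `len(modality_encodings) != modality_messages_after_truncation`
    # guard is commented out, so all encodings are returned unconditionally
    return modality_encodings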
7 changes: 5 additions & 2 deletions turbo_alignment/modeling/multimodal/lm/projection.py
@@ -107,6 +107,8 @@ def convert_inputs_to_embeds(
replica_spans = self.get_replica_spans(sample_input_ids)
filtered_replica_spans = self.filter_replica_spans_with_modality(replica_spans, modality_spans)

# print("😆"*10, f"{len(modality_spans)=} {len(replica_spans)=} {len(sample_modality_inputs)=}")

assert len(modality_spans) == len(sample_modality_inputs)

grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
@@ -121,7 +123,7 @@ def convert_inputs_to_embeds(
).to(self.language_model.device)

# Encode modalities and insert into input embeds
for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():
for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items(): # FIXME: works only with image modality
modality_encoder_input_indexes, modality_encoder_inputs = zip(*modality_encoder_inputs_with_indices)

if self.language_model.dtype == torch.float32:
@@ -134,10 +136,11 @@ def convert_inputs_to_embeds(
)

modality_replica_lm_input_embeds = pad_sequence(
[sample_lm_input_embeds[span[0] : span[1]] for span in filtered_replica_spans],
[sample_lm_input_embeds[replica_span[0] : modality_span[0].start] for replica_span, modality_span in zip(filtered_replica_spans, modality_spans)],
padding_value=0,
batch_first=True,
)

modality_encoder_embeddings = self.modality_adapters[modality](
encoded_modality_object_batch, modality_replica_lm_input_embeds
)
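To make the new span arithmetic concrete: each filtered replica span is now paired one-to-one with a modality span and cut off where that modality's tokens begin, so the adapter only sees the replica text that precedes the modality object. A small illustrative helper with assumed shapes and types, not the repository's code:

import torch
from torch.nn.utils.rnn import pad_sequence

def build_replica_embeds(
    sample_lm_input_embeds: torch.Tensor,           # [seq_len, hidden]
    filtered_replica_spans: list[tuple[int, int]],  # (start, end) per replica
    modality_start_positions: list[int],            # start index of each modality span
) -> torch.Tensor:
    # Mirrors the new slicing: replica start up to the matching modality start,
    # padded to a common length -> [n_replicas, max_len, hidden]
    return pad_sequence(
        [
            sample_lm_input_embeds[replica_start:modality_start]
            for (replica_start, _replica_end), modality_start in zip(
                filtered_replica_spans, modality_start_positions
            )
        ],
        padding_value=0,
        batch_first=True,
    )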
93 changes: 61 additions & 32 deletions turbo_alignment/modeling/multimodal/projectors/llava.py
@@ -55,44 +55,44 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
# return weighted_image_features


@MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
class LlavaWithTextMultiModalProjector(torch.nn.Module):
def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
super().__init__()
self.encoder_hidden_size = encoder_hidden_size
self.text_hidden_size = text_hidden_size
self.k = n_modality_embs # Number of top patches to select
self.encoder_to_text_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
self.text_to_text_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
self.cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)
self.output_layer = torch.nn.Linear(text_hidden_size, text_hidden_size)
# @MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
# class LlavaWithTextMultiModalProjector(torch.nn.Module):
# def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
# super().__init__()
# self.encoder_hidden_size = encoder_hidden_size
# self.text_hidden_size = text_hidden_size
# self.k = n_modality_embs # Number of top patches to select
# self.encoder_to_text_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
# self.text_to_text_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
# self.cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)
# self.output_layer = torch.nn.Linear(text_hidden_size, text_hidden_size)

def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
# Project the image features into the text hidden space
projected_image = self.encoder_to_text_projection(image_features)
projected_text = self.text_to_text_projection(text_features)
# def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
# # Project the image features into the text hidden space
# projected_image = self.encoder_to_text_projection(image_features)
# projected_text = self.text_to_text_projection(text_features)

# Permute dimensions for attention
permuted_projected_image = projected_image.permute(1, 0, 2) # [image_patches, batch_size, hidden_dim]
permuted_projected_text = projected_text.permute(1, 0, 2) # [textual_tokens, batch_size, hidden_dim]
# # Permute dimensions for attention
# permuted_projected_image = projected_image.permute(1, 0, 2) # [image_patches, batch_size, hidden_dim]
# permuted_projected_text = projected_text.permute(1, 0, 2) # [textual_tokens, batch_size, hidden_dim]

# Cross-attention: text tokens attend to image patches
_, attention_weights = self.cross_attention(
query=permuted_projected_text, # Text queries attend to image patches
key=permuted_projected_image,
value=permuted_projected_image
)
# # Cross-attention: text tokens attend to image patches
# _, attention_weights = self.cross_attention(
# query=permuted_projected_text, # Text queries attend to image patches
# key=permuted_projected_image,
# value=permuted_projected_image
# )

# Average attention weights over text tokens to get importance scores for image patches
avg_attention_weights = attention_weights.mean(dim=1) # [batch_size, image_patches]
# # Average attention weights over text tokens to get importance scores for image patches
# avg_attention_weights = attention_weights.mean(dim=1) # [batch_size, image_patches]

# Select top-k patches based on attention scores
_, topk_indices = torch.topk(avg_attention_weights, self.k, dim=1) # [batch_size, k]
topk_image_patches = projected_image.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, projected_image.size(-1))) # [batch_size, k, hidden_dim]
# # Select top-k patches based on attention scores
# _, topk_indices = torch.topk(avg_attention_weights, self.k, dim=1) # [batch_size, k]
# topk_image_patches = projected_image.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, projected_image.size(-1))) # [batch_size, k, hidden_dim]

# Map the top-k patches into the LM embedding space
topk_mapped_patches = self.output_layer(topk_image_patches) # [batch_size, k, text_hidden_size]
return topk_mapped_patches # Output: [batch_size, k, lm_dim]
# # Map the top-k patches into the LM embedding space
# topk_mapped_patches = self.output_layer(topk_image_patches) # [batch_size, k, text_hidden_size]
# return topk_mapped_patches # Output: [batch_size, k, lm_dim]


# @MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
@@ -126,3 +126,32 @@ def forward(self, image_features: torch.Tensor, text_features: torch.Tensor) ->

# mapped_attentioned_values = self.output_layer(attention_values)
# return mapped_attentioned_values


@MultiModalProjectorRegistry.register(ModalityProjectorType.LLAVA_WITH_REPLICA)
class LlavaWithTextMultiModalProjector(torch.nn.Module):
def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
super().__init__()
self.cross_attention = torch.nn.MultiheadAttention(embed_dim=text_hidden_size, num_heads=8)
self.image_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
self.text_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)
self.final_projection = torch.nn.Linear(text_hidden_size, text_hidden_size)

def forward(self, image_features, text_features):
projected_image = self.image_projection(image_features)
# projected_text = self.text_projection(text_features)

image_patches = projected_image.transpose(0, 1)
# text_tokens = projected_text.transpose(0, 1)
text_tokens = text_features.transpose(0, 1)

_, attention_weights = self.cross_attention(query=text_tokens,
key=image_patches,
value=image_patches)

patch_importance = attention_weights.mean(dim=1)

attended_image_features = projected_image * patch_importance.unsqueeze(-1)

# return self.final_projection(attended_image_features.sum(1).unsqueeze(1))
return attended_image_features.sum(1).unsqueeze(1)

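A quick shape check for the re-implemented LlavaWithTextMultiModalProjector: cross-attention weights from text tokens to image patches are averaged into per-patch importances, the projected patches are re-weighted, and their sum yields one embedding per sample. All sizes below are hypothetical, and the import path is inferred from the file location.

import torch
from turbo_alignment.modeling.multimodal.projectors.llava import LlavaWithTextMultiModalProjector

projector = LlavaWithTextMultiModalProjector(
    encoder_hidden_size=1024,  # vision encoder patch dim (assumed)
    text_hidden_size=2048,     # LM hidden dim (assumed); must be divisible by num_heads=8
    n_modality_embs=256,       # accepted for registry compatibility, unused in forward
)

image_features = torch.randn(2, 576, 1024)  # [batch, image_patches, encoder_hidden_size]
text_features = torch.randn(2, 32, 2048)    # [batch, text_tokens, text_hidden_size]

out = projector(image_features, text_features)
assert out.shape == (2, 1, 2048)  # one pooled, text-conditioned image embedding per sample

Note that text_features must already be in the LM hidden size here, since text_projection is defined but bypassed in this version of forward.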