Commit 05a6db0: in progress
lmeribal committed Oct 15, 2024
1 parent 77aeb56 commit 05a6db0
Showing 10 changed files with 69 additions and 56 deletions.
turbo_alignment/cherry_picks/multimodal.py (2 changes: 1 addition & 1 deletion)
@@ -1,10 +1,10 @@
from typing import Iterable

import torch
import wandb
from PIL import Image
from transformers import PreTrainedTokenizerBase

import wandb
from turbo_alignment.cherry_picks.base import CherryPickCallbackBase
from turbo_alignment.dataset.multimodal import InferenceMultimodalDataset
from turbo_alignment.generators.multimodal import MultimodalGenerator
turbo_alignment/common/logging/weights_and_biases.py (2 changes: 1 addition & 1 deletion)
@@ -1,9 +1,9 @@
from typing import Any

import wandb
from wandb.sdk.lib.disabled import RunDisabled
from wandb.sdk.wandb_run import Run

import wandb
from turbo_alignment.settings.logging.weights_and_biases import WandbSettings


turbo_alignment/common/tf/callbacks/logging.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,6 @@

import numpy as np
import pandas as pd
import wandb
from clearml import Task
from transformers import (
TrainerCallback,
@@ -14,6 +13,7 @@
from wandb.sdk.lib.disabled import RunDisabled
from wandb.sdk.wandb_run import Run

import wandb
from turbo_alignment.common.logging import get_project_logger

logger = get_project_logger()
turbo_alignment/dataset/multimodal/collators.py (4 changes: 2 additions & 2 deletions)
@@ -7,9 +7,9 @@ class DataCollatorWithModalityInputs(DataCollatorForTokenClassification):
def torch_call(self, features):
label_name = 'label' if 'label' in features[0].keys() else 'labels'
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
# print("😍"*10, features[0].keys())
if 'modality_inputs' in features[0].keys():
modality_inputs = [feature['modality_inputs'] for feature in features]
print([feature['modality_inputs'] for feature in features])
modality_inputs = torch.stack([torch.stack(feature['modality_inputs']) for feature in features])
else:
modality_inputs = [None for _ in features]

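The torch_call change above replaces the list of per-sample modality tensors with a single stacked batch tensor. A minimal sketch of the new collation path, assuming each feature carries a list of identically shaped modality tensors (the helper name is illustrative, not the project's API):

import torch

def collate_modality_inputs(features: list[dict]):
    # Each feature may carry 'modality_inputs', a list of per-modality tensors.
    if 'modality_inputs' in features[0]:
        # torch.stack requires every sample to hold the same number of
        # identically shaped tensors; ragged batches raise a RuntimeError.
        return torch.stack(
            [torch.stack(feature['modality_inputs']) for feature in features]
        )  # shape: (batch, n_modality_objects, *encoder_input_shape)
    return [None for _ in features]

The list-based version this replaces had no shape constraint; the stacked form trades that flexibility for a single tensor that can be moved to the device in one call.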
turbo_alignment/dataset/multimodal/multimodal.py (8 changes: 3 additions & 5 deletions)
@@ -6,7 +6,6 @@

import numpy as np
import torch
from turbo_alignment.common.registry import Params

from turbo_alignment.common.data.io import read_jsonl
from turbo_alignment.common.data.multimodal import BaseModalityReader
@@ -194,16 +193,14 @@ def _read_modalities(self, record):
try:
for msg in modality_messages:
reader = self._modality_readers[msg.type]
modality_encodings.append((msg.type, reader.read(msg.content)))
# modality_encodings.append((msg.type, reader.read(msg.content)))
modality_encodings.append(reader.read(msg.content))
except (OSError, RuntimeError, KeyError):
return None

# record['modality_inputs'] = modality_encodings

if len(modality_encodings) != modality_messages_after_truncation:
return None

# return record
return modality_encodings

def __iter__(self):
@@ -218,6 +215,7 @@ def __iter__(self):
end = min(start + per_worker, end)
for i, sample in enumerate(self.records[start:end]):
output = self._read_modalities(sample)

if output:
yield sample | {'modality_inputs': output}

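The __iter__ hunk above shards self.records across DataLoader workers and attaches the encodings returned by _read_modalities under the 'modality_inputs' key. A self-contained sketch of that sharding pattern, with _read_modalities stubbed out; the stub and the per-worker split formula (the standard IterableDataset recipe, whose head is truncated in the diff) are assumptions, not the project's exact code:

import torch
from torch.utils.data import IterableDataset, get_worker_info

class ShardedRecordsDataset(IterableDataset):
    def __init__(self, records: list[dict]):
        self.records = records

    def _read_modalities(self, record: dict):
        # Stub: the real method decodes each modality message and returns
        # a list of encodings, or None when reading or truncation fails.
        return [torch.zeros(3)]

    def __iter__(self):
        start, end = 0, len(self.records)
        worker_info = get_worker_info()
        if worker_info is not None:
            # Give each worker a contiguous, non-overlapping slice.
            per_worker = (end - start) // worker_info.num_workers + 1
            start = start + worker_info.id * per_worker
            end = min(start + per_worker, end)
        for sample in self.records[start:end]:
            output = self._read_modalities(sample)
            if output:
                yield sample | {'modality_inputs': output}

As in the diff, a falsy output (None or an empty list) silently drops the sample, so the number of yielded records can be smaller than len(self.records).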
turbo_alignment/modeling/multimodal/lm/projection.py (48 changes: 29 additions & 19 deletions)
@@ -71,38 +71,48 @@ def convert_inputs_to_embeds(
) # returns mask with ids of spans from 1 to N
modality_spans = find_objects(span_mask) # returns list of tuples with start index and end index

print(len(sample_modality_inputs), len(modality_spans))
# exit()

assert len(modality_spans) == len(sample_modality_inputs)

grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
# grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
grouped_modality_encoder_inputs = []

# Prepare modality batches
for index, modality_object in enumerate(sample_modality_inputs):
modality, inputs = modality_object
grouped_modality_encoder_inputs[modality].append((index, inputs))
# modality, inputs = modality_object
# grouped_modality_encoder_inputs[modality].append((index, inputs))
inputs = modality_object
grouped_modality_encoder_inputs.append((index, inputs))

sorted_modality_embeddings: torch.Tensor = torch.full(
(len(sample_modality_inputs), self.n_modality_embs, self.language_model_dim), torch.nan
).to(self.language_model.device)

# Encode modalities and insert into input embeds
for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():
modality_encoder_input_indexes, modality_encoder_inputs = zip(*modality_encoder_inputs_with_indices)

if self.language_model.dtype == torch.float32:
encoded_modality_object_batch = self.encoders[modality].encode(
torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device)
)
else:
encoded_modality_object_batch = self.encoders[modality].encode(
torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device).bfloat16()
)

modality_encoder_embeddings = self.modality_adapters[modality](encoded_modality_object_batch)

sorted_modality_embeddings[modality_encoder_input_indexes, :] = modality_encoder_embeddings.to(
sorted_modality_embeddings.dtype
# for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():
# for modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs:
modality_encoder_inputs_with_indices = grouped_modality_encoder_inputs
modality = Modality.IMAGE
modality_encoder_input_indexes, modality_encoder_inputs = zip(*modality_encoder_inputs_with_indices)
# modality_encoder_input_indexes, modality_encoder_inputs = modality_encoder_inputs_with_indices

if self.language_model.dtype == torch.float32:
encoded_modality_object_batch = self.encoders[modality].encode(
torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device)
)
else:
encoded_modality_object_batch = self.encoders[modality].encode(
torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device).bfloat16()
)

modality_encoder_embeddings = self.modality_adapters[modality](encoded_modality_object_batch)

sorted_modality_embeddings[modality_encoder_input_indexes, :] = modality_encoder_embeddings.to(
sorted_modality_embeddings.dtype
)

substituted_sample_lm_input_embeds = sample_lm_input_embeds.clone()
for i, modality_span in enumerate(modality_spans):
substituted_sample_lm_input_embeds[
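The projection.py hunk above drops the per-modality grouping (a defaultdict keyed by Modality) in favor of a flat list and a hard-coded Modality.IMAGE, so every modality object now goes through the image encoder. A condensed sketch of the resulting encode-and-scatter step, with simplified signatures that are not the project's exact API:

import torch

def encode_and_scatter(modality_inputs, encoder, adapter,
                       n_modality_embs, lm_dim, device, lm_dtype):
    # modality_inputs: one tensor per modality span, already in span order.
    # Placeholder filled with NaN so unwritten rows are easy to spot.
    sorted_embeddings = torch.full(
        (len(modality_inputs), n_modality_embs, lm_dim), torch.nan, device=device
    )
    batch = torch.stack(modality_inputs, dim=0).to(device)
    if lm_dtype != torch.float32:
        # Mirror the diff: cast encoder inputs when the LM is not float32.
        batch = batch.bfloat16()
    encoded = encoder.encode(batch)     # modality encoder forward pass
    embeddings = adapter(encoded)       # project into the LM embedding space
    indexes = list(range(len(modality_inputs)))
    sorted_embeddings[indexes, :] = embeddings.to(sorted_embeddings.dtype)
    return sorted_embeddings

With the grouping gone, the zip over (index, tensor) pairs in the diff is effectively an identity permutation; it is a leftover from the grouped version, where indices from different modalities could interleave.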
turbo_alignment/modeling/multimodal/projectors/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
from turbo_alignment.modeling.multimodal.projectors.attention_pooling import (
AttentionPoolingMultiModalProjector,
TopKAttentionPoolingMultiModalProjector,
TopKAttentionPoolingWithNHeadsMultiModalProjector
TopKAttentionPoolingWithNHeadsMultiModalProjector,
)
from turbo_alignment.modeling.multimodal.projectors.c_abstractor import CAbstractor
from turbo_alignment.modeling.multimodal.projectors.llava import (
turbo_alignment/modeling/multimodal/projectors/attention_pooling.py (52 changes: 29 additions & 23 deletions)
@@ -1,12 +1,13 @@
import math

import numpy as np
import torch
import torch.nn.functional as F

from turbo_alignment.modeling.multimodal.projectors.registry import (
MultiModalProjectorRegistry,
)
from turbo_alignment.settings.modality import ModalityProjectorType
import torch.nn.functional as F
import math
import numpy as np


def get_abs_pos(abs_pos, tgt_size):
@@ -18,15 +19,21 @@ def get_abs_pos(abs_pos, tgt_size):
dtype = abs_pos.dtype

if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
return (
F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode='bicubic',
align_corners=False,
)
.permute(0, 2, 3, 1)
.flatten(0, 2)
.to(dtype=dtype)
)
else:
return abs_pos


# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/visual.py
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
@@ -54,7 +61,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb


@@ -66,18 +73,19 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product

emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)

emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb


@MultiModalProjectorRegistry.register(ModalityProjectorType.ATTENTION_POOLING)
class AttentionPoolingMultiModalProjector(torch.nn.Module):
def __init__(self, encoder_hidden_size: int, text_hidden_size: int, n_modality_embs: int):
@@ -115,18 +123,18 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
projected_features = self.linear_projection(
image_features
) # map each image patch to the language model dimension

projected_features = projected_features + pos_embed

attention_scores = torch.softmax(
self.attention_scores(projected_features), 1
) # calculate learnable attention scores for each patch
top_indices = torch.topk(
attention_scores.squeeze(-1), k=attention_scores.shape[1], dim=1
).indices # select indices top N patches according to attention scores

projected_features[:, top_indices[:, self.top_k:].squeeze(0)] = 0 # set zero for unselected tokens
projected_features = projected_features[(projected_features != 0).any(dim=-1)] # remove zero vectors
projected_features[:, top_indices[:, self.top_k :].squeeze(0)] = 0 # set zero for unselected tokens
projected_features = projected_features[(projected_features != 0).any(dim=-1)] # remove zero vectors

return projected_features.unsqueeze(0)

@@ -156,17 +164,15 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
# projected_features = projected_features + pos_embed

scores = self.attention_scores(projected_features)
attention_scores = torch.softmax(
scores, 1
) # calculate learnable attention scores for each patch
attention_scores = torch.softmax(scores, 1) # calculate learnable attention scores for each patch
attention_scores = torch.max(attention_scores, -1).values
# attention_scores = torch.mean(attention_scores, -1)
top_indices = torch.topk(
attention_scores.squeeze(-1), k=attention_scores.shape[1], dim=1
).indices # select indices top N patches according to attention scores

projected_features[:, top_indices[:, self.top_k:].squeeze(0)] = 0 # set zero for unselected tokens
projected_features = projected_features[(projected_features != 0).any(dim=-1)] # remove zero vectors
projected_features[:, top_indices[:, self.top_k :].squeeze(0)] = 0 # set zero for unselected tokens
projected_features = projected_features[(projected_features != 0).any(dim=-1)] # remove zero vectors

return projected_features.unsqueeze(0)

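attention_pooling.py pairs a learned per-patch attention score with top-k patch selection (the hunks above also reformat the sin-cos positional-embedding helpers borrowed from Qwen-VL and MAE). A compact sketch of the selection step, assuming batch size 1 as the squeeze(0)/unsqueeze(0) calls in the diff imply; the class name and constructor are illustrative:

import torch

class TopKAttentionPoolingSketch(torch.nn.Module):
    def __init__(self, encoder_hidden_size: int, text_hidden_size: int, top_k: int):
        super().__init__()
        self.linear_projection = torch.nn.Linear(encoder_hidden_size, text_hidden_size)
        self.attention_scores = torch.nn.Linear(text_hidden_size, 1)
        self.top_k = top_k

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # image_features: (1, n_patches, encoder_hidden_size)
        projected = self.linear_projection(image_features)
        scores = torch.softmax(self.attention_scores(projected), dim=1)
        # topk with k = n_patches yields a full ranking; everything past
        # the first top_k entries is zeroed out and then filtered away.
        ranking = torch.topk(scores.squeeze(-1), k=scores.shape[1], dim=1).indices
        projected = projected.clone()  # the diff assigns in place instead
        projected[:, ranking[:, self.top_k:].squeeze(0)] = 0
        projected = projected[(projected != 0).any(dim=-1)]  # drop zeroed rows
        return projected.unsqueeze(0)  # (1, top_k, text_hidden_size)

One caveat carried over from the original: a patch whose projected vector is exactly zero is indistinguishable from an unselected one, so the output can occasionally hold fewer than top_k rows.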
turbo_alignment/pipelines/preprocessing/multimodal.py (3 changes: 1 addition & 2 deletions)
@@ -1,3 +1,4 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Tuple
@@ -6,8 +7,6 @@
import numpy as np
import torch
from accelerate import Accelerator
from turbo_alignment.common.registry import Params
import os
from accelerate.utils import gather_object
from safetensors.torch import save_file
from tqdm import tqdm
turbo_alignment/trainers/dpo.py (2 changes: 1 addition & 1 deletion)
Expand Up @@ -26,9 +26,9 @@
from turbo_alignment.settings.pipelines.train.dpo import (
APODownLossSettings,
APOZeroLossSettings,
ASFTLossSettings,
CPOLossSettings,
DPOLossesType,
ASFTLossSettings,
HingeLossSettings,
IPOLossSettings,
KTOLossSettings,
