diff --git a/turbo_alignment/dataset/multimodal/collators.py b/turbo_alignment/dataset/multimodal/collators.py
index 42ea89e..d30fc07 100644
--- a/turbo_alignment/dataset/multimodal/collators.py
+++ b/turbo_alignment/dataset/multimodal/collators.py
@@ -8,8 +8,7 @@ def torch_call(self, features):
         label_name = 'label' if 'label' in features[0].keys() else 'labels'
         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
         if 'modality_inputs' in features[0].keys():
-            # print([feature['modality_inputs'] for feature in features])
-            modality_inputs = torch.stack([torch.stack(feature['modality_inputs']) for feature in features])
+            modality_inputs = [feature['modality_inputs'] for feature in features]
         else:
             modality_inputs = [None for _ in features]
 
@@ -55,7 +54,6 @@ def torch_call(self, features):
                 for feature in features
             ]
         )
-        del features
 
         if labels is None:
             return batch
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index ee08321..a86bf87 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -189,19 +189,20 @@ def _read_modalities(self, record):
         else:
             modality_messages = modality_messages[:modality_messages_after_truncation]
 
-        # modality_encodings: list[tuple[Modality, torch.Tensor]] = []
-        modality_encodings = []
+        modality_encodings: list[tuple[Modality, torch.Tensor]] = []
         try:
             for msg in modality_messages:
                 reader = self._modality_readers[msg.type]
-                # modality_encodings.append((msg.type, reader.read(msg.content)))
-                modality_encodings.append(reader.read(msg.content))
+                modality_encodings.append((msg.type, reader.read(msg.content)))
         except (OSError, RuntimeError, KeyError):
             return None
+        # record['modality_inputs'] = modality_encodings
+
 
         if len(modality_encodings) != modality_messages_after_truncation:
             return None
 
+        # return record
         return modality_encodings
 
     def __iter__(self):
@@ -216,8 +217,7 @@ def __iter__(self):
         end = min(start + per_worker, end)
         for i, sample in enumerate(self.records[start:end]):
             output = self._read_modalities(sample)
-
-            if output is not None:
+            if output:
                 yield sample | {'modality_inputs': output}
 
 
@@ -256,8 +256,8 @@ def _read_modalities(
         for msg in modality_messages:
             reader = self._modality_readers[msg.type]
             print('inference reader')
-            # modality_encodings.append((msg.type, reader.read(msg.content)))
-            modality_encodings.append(reader.read(msg.content))
+            modality_encodings.append((msg.type, reader.read(msg.content)))
+            # modality_encodings.append(reader.read(msg.content))
         return modality_encodings
 
     def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index be95990..24a3832 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -73,42 +73,40 @@ def convert_inputs_to_embeds(
 
         assert len(modality_spans) == len(sample_modality_inputs)
 
-        # grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
-        grouped_modality_encoder_inputs = []
+        grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
 
         # Prepare modality batches
         for index, modality_object in enumerate(sample_modality_inputs):
-            # modality, inputs = modality_object
-            # grouped_modality_encoder_inputs[modality].append((index, inputs))
-            inputs = modality_object
-            grouped_modality_encoder_inputs.append((index, inputs))
+            modality, inputs = modality_object
+            grouped_modality_encoder_inputs[modality].append((index, inputs))
 
         sorted_modality_embeddings: torch.Tensor = torch.full(
             (len(sample_modality_inputs), self.n_modality_embs, self.language_model_dim), torch.nan
         ).to(self.language_model.device)
 
         # Encode modalities and insert into input embeds
-        # for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():
-        # for modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs:
-        modality_encoder_inputs_with_indices = grouped_modality_encoder_inputs
-        modality = Modality.IMAGE
-        modality_encoder_input_indexes, modality_encoder_inputs = zip(*modality_encoder_inputs_with_indices)
-        # modality_encoder_input_indexes, modality_encoder_inputs = modality_encoder_inputs_with_indices
-
-        if self.language_model.dtype == torch.float32:
-            encoded_modality_object_batch = self.encoders[modality].encode(
-                torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device)
+        for modality, modality_encoder_inputs_with_indices in grouped_modality_encoder_inputs.items():
+            modality_encoder_input_indexes, modality_encoder_inputs = zip(*modality_encoder_inputs_with_indices)
+            # print(modality_encoder_input_indexes)
+            # exit()
+
+            if self.language_model.dtype == torch.float32:
+                encoded_modality_object_batch = self.encoders[modality].encode(
+                    torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device)
+                )
+            else:
+                encoded_modality_object_batch = self.encoders[modality].encode(
+                    torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device).bfloat16()
+                )
+
+            modality_encoder_embeddings = self.modality_adapters[modality](encoded_modality_object_batch)
+
+            # print(sorted_modality_embeddings[modality_encoder_input_indexes, :].shape, sorted_modality_embeddings.shape, modality_encoder_embeddings.shape)
+            # exit()
+
+            sorted_modality_embeddings[modality_encoder_input_indexes, :] = modality_encoder_embeddings.to(
+                sorted_modality_embeddings.dtype
             )
-        else:
-            encoded_modality_object_batch = self.encoders[modality].encode(
-                torch.stack(modality_encoder_inputs, dim=0).to(self.language_model.device).bfloat16()
-            )
-
-        modality_encoder_embeddings = self.modality_adapters[modality](encoded_modality_object_batch)
-
-        sorted_modality_embeddings[modality_encoder_input_indexes, :] = modality_encoder_embeddings.to(
-            sorted_modality_embeddings.dtype
-        )
 
         substituted_sample_lm_input_embeds = sample_lm_input_embeds.clone()
         for i, modality_span in enumerate(modality_spans):
diff --git a/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py b/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
index 200134a..e2d26bc 100644
--- a/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
+++ b/turbo_alignment/modeling/multimodal/projectors/attention_pooling.py
@@ -174,7 +174,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         projected_features[:, top_indices[:, self.top_k :].squeeze(0)] = 0  # set zero for unselected tokens
         projected_features = projected_features[(projected_features != 0).any(dim=-1)]  # remove zero vectors
 
+        # print(self.top_k, projected_features.shape)
+        # exit()
         return projected_features.unsqueeze(0)
+        # return projected_features
 
 
 @MultiModalProjectorRegistry.register(ModalityProjectorType.THRESHOLD_SELECTOR)
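
For reference, a minimal, self-contained sketch (not part of the patch) of the per-modality batching that the projection.py hunk restores: each (position, tensor) pair is grouped by its modality, every group is encoded as one stacked batch, and the resulting embeddings are scattered back to their original positions. The Modality enum, encoder/adapter callables, and dimensions here are simplified, hypothetical stand-ins for the repository's own types, not its actual API.

    from collections import defaultdict
    from enum import Enum

    import torch


    class Modality(str, Enum):
        IMAGE = 'image'
        AUDIO = 'audio'


    def encode_by_modality(sample_modality_inputs, encoders, adapters, n_modality_embs, hidden_dim):
        # Group (position, tensor) pairs by modality so each encoder runs once per batch.
        grouped: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
        for index, (modality, inputs) in enumerate(sample_modality_inputs):
            grouped[modality].append((index, inputs))

        # Placeholder filled per modality; rows return to their original order via the saved indexes.
        embeddings = torch.full((len(sample_modality_inputs), n_modality_embs, hidden_dim), torch.nan)
        for modality, items in grouped.items():
            indexes, inputs = zip(*items)
            encoded = encoders[modality](torch.stack(inputs, dim=0))
            embeddings[list(indexes)] = adapters[modality](encoded)
        return embeddings


    # Toy usage: identity "encoder" and a linear "adapter" for two image tensors.
    encoders = {Modality.IMAGE: lambda x: x}
    adapters = {Modality.IMAGE: torch.nn.Linear(16, 32)}
    inputs = [(Modality.IMAGE, torch.randn(4, 16)), (Modality.IMAGE, torch.randn(4, 16))]
    out = encode_by_modality(inputs, encoders, adapters, n_modality_embs=4, hidden_dim=32)
    print(out.shape)  # torch.Size([2, 4, 32])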