Skip to content

Commit

Permalink
in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Elisei Rykov committed Oct 19, 2024
1 parent cc53839 commit e81c667
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
9 changes: 4 additions & 5 deletions turbo_alignment/dataset/multimodal/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ def torch_call(self, features):
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
if 'modality_inputs' in features[0].keys():
# print([feature['modality_inputs'] for feature in features])
- modality_inputs = torch.stack(
- [feature['modality_inputs'] for feature in features]
- ).contiguous()
+ modality_inputs = torch.stack([torch.stack(feature['modality_inputs']) for feature in features])
else:
modality_inputs = [None for _ in features]

Expand Down Expand Up @@ -56,7 +54,8 @@ def torch_call(self, features):
)
for feature in features
]
- ).contiguous()
+ )
+ del features

if labels is None:
return batch
Expand All @@ -65,6 +64,6 @@ def torch_call(self, features):
label.tolist() + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
]

- batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64).contiguous()
+ batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)

return batch
4 changes: 2 additions & 2 deletions turbo_alignment/dataset/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _read_modalities(self, record):
if len(modality_encodings) != modality_messages_after_truncation:
return None

- return torch.stack(modality_encodings)
+ return modality_encodings

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
Expand Down Expand Up @@ -258,7 +258,7 @@ def _read_modalities(
print('inference reader')
# modality_encodings.append((msg.type, reader.read(msg.content)))
modality_encodings.append(reader.read(msg.content))
- return torch.stack(modality_encodings)
+ return modality_encodings

def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
chat_records = [self._convert_to_chat(r) for r in records]
Expand Down
7 changes: 2 additions & 5 deletions turbo_alignment/modeling/multimodal/lm/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def convert_inputs_to_embeds(
) # returns mask with ids of spans from 1 to N
modality_spans = find_objects(span_mask) # returns list of tuples with start index and end index

- # print(len(sample_modality_inputs), len(modality_spans))
- # exit()
-
assert len(modality_spans) == len(sample_modality_inputs)

# grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
Expand All @@ -83,7 +80,7 @@ def convert_inputs_to_embeds(
for index, modality_object in enumerate(sample_modality_inputs):
# modality, inputs = modality_object
# grouped_modality_encoder_inputs[modality].append((index, inputs))
- inputs = modality_object.contiguous()
+ inputs = modality_object
grouped_modality_encoder_inputs.append((index, inputs))

sorted_modality_embeddings: torch.Tensor = torch.full(
Expand Down Expand Up @@ -131,7 +128,7 @@ def forward(
modality_tokens_mask: torch.Tensor,
labels: torch.LongTensor | None = None,
) -> ModelOutput:
- multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs.contiguous(), modality_tokens_mask)
+ multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask)
return self.language_model(
inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask
)

0 comments on commit e81c667

Please sign in to comment.