From e81c66711c1b64294828529fc6b7df803973d1e9 Mon Sep 17 00:00:00 2001
From: Elisei Rykov
Date: Sat, 19 Oct 2024 17:32:50 +0300
Subject: [PATCH] in progress

---
 turbo_alignment/dataset/multimodal/collators.py      | 9 ++++-----
 turbo_alignment/dataset/multimodal/multimodal.py     | 4 ++--
 turbo_alignment/modeling/multimodal/lm/projection.py | 7 ++-----
 3 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/turbo_alignment/dataset/multimodal/collators.py b/turbo_alignment/dataset/multimodal/collators.py
index 1be75a8..42ea89e 100644
--- a/turbo_alignment/dataset/multimodal/collators.py
+++ b/turbo_alignment/dataset/multimodal/collators.py
@@ -9,9 +9,7 @@ def torch_call(self, features):
         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
 
         if 'modality_inputs' in features[0].keys():
             # print([feature['modality_inputs'] for feature in features])
-            modality_inputs = torch.stack(
-                [feature['modality_inputs'] for feature in features]
-            ).contiguous()
+            modality_inputs = torch.stack([torch.stack(feature['modality_inputs']) for feature in features])
         else:
             modality_inputs = [None for _ in features]
@@ -56,7 +54,8 @@ def torch_call(self, features):
                 )
                 for feature in features
             ]
-        ).contiguous()
+        )
+        del features
 
         if labels is None:
             return batch
@@ -65,6 +64,6 @@ def torch_call(self, features):
             label.tolist() + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
         ]
-        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64).contiguous()
+        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
 
         return batch
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index ccc9c2d..ee08321 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -202,7 +202,7 @@ def _read_modalities(self, record):
         if len(modality_encodings) != modality_messages_after_truncation:
             return None
-        return torch.stack(modality_encodings)
+        return modality_encodings
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
@@ -258,7 +258,7 @@ def _read_modalities(
                 print('inference reader')
                 # modality_encodings.append((msg.type, reader.read(msg.content)))
                 modality_encodings.append(reader.read(msg.content))
-        return torch.stack(modality_encodings)
+        return modality_encodings
 
     def convert_records(self, records: list[MultimodalDatasetRecord]) -> list[dict[str, Any] | None]:
         chat_records = [self._convert_to_chat(r) for r in records]
diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index 33193b7..be95990 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -71,9 +71,6 @@ def convert_inputs_to_embeds(
         )  # returns mask with ids of spans from 1 to N
         modality_spans = find_objects(span_mask)  # returns list of tuples with start index and end index
-        # print(len(sample_modality_inputs), len(modality_spans))
-        # exit()
-
        assert len(modality_spans) == len(sample_modality_inputs)

         # grouped_modality_encoder_inputs: dict[Modality, list[tuple[int, torch.Tensor]]] = defaultdict(list)
@@ -83,7 +80,7 @@ def convert_inputs_to_embeds(
         for index, modality_object in enumerate(sample_modality_inputs):
             # modality, inputs = modality_object
             # grouped_modality_encoder_inputs[modality].append((index, inputs))
-            inputs = modality_object.contiguous()
+            inputs = modality_object
             grouped_modality_encoder_inputs.append((index, inputs))
 
         sorted_modality_embeddings: torch.Tensor = torch.full(
@@ -131,7 +128,7 @@ def forward(
         modality_tokens_mask: torch.Tensor,
         labels: torch.LongTensor | None = None,
     ) -> ModelOutput:
-        multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs.contiguous(), modality_tokens_mask)
+        multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask)
         return self.language_model(
             inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask
         )
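
Note on the data flow this patch moves toward (not code from the repository): _read_modalities now returns a plain Python list of per-message encodings instead of a pre-stacked tensor, and the collator does all the stacking, once per sample and once across the batch. The sketch below illustrates that shape contract in isolation; the tensor shapes and the read_modalities_stub helper are made up for illustration only.

import torch

# Stand-in for the dataset side: a plain list of per-message encodings
# (toy (num_patches, hidden_dim) shapes, for illustration only).
def read_modalities_stub(num_messages: int) -> list[torch.Tensor]:
    return [torch.randn(4, 16) for _ in range(num_messages)]

# Two features in a batch, each with the same number of modality messages --
# the double torch.stack in the collator relies on this.
features = [
    {'modality_inputs': read_modalities_stub(2)},
    {'modality_inputs': read_modalities_stub(2)},
]

# Collator side: stack each sample's list first, then stack across the batch,
# yielding a (batch_size, num_modality_messages, 4, 16) tensor.
modality_inputs = torch.stack(
    [torch.stack(feature['modality_inputs']) for feature in features]
)
print(modality_inputs.shape)  # torch.Size([2, 2, 4, 16])

One consequence of the double stack is that every sample in a batch must carry the same number of identically shaped modality encodings; the dataset's truncation check (len(modality_encodings) != modality_messages_after_truncation returning None) is what keeps mismatched samples out of the batch.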