diff --git a/turbo_alignment/dataset/multimodal/collators.py b/turbo_alignment/dataset/multimodal/collators.py
index 5197124..3dd7967 100644
--- a/turbo_alignment/dataset/multimodal/collators.py
+++ b/turbo_alignment/dataset/multimodal/collators.py
@@ -10,7 +10,7 @@ def torch_call(self, features):
         if 'modality_inputs' in features[0].keys():
             # print([feature['modality_inputs'] for feature in features])
             modality_inputs = torch.stack(
-                [torch.stack(feature['modality_inputs']) for feature in features]
+                [torch.stack(feature['modality_inputs']).contiguous() for feature in features]
             ).contiguous()
         else:
             modality_inputs = [None for _ in features]
diff --git a/turbo_alignment/dataset/multimodal/multimodal.py b/turbo_alignment/dataset/multimodal/multimodal.py
index c100d39..ccbb65f 100644
--- a/turbo_alignment/dataset/multimodal/multimodal.py
+++ b/turbo_alignment/dataset/multimodal/multimodal.py
@@ -194,6 +194,6 @@ def _read_modalities(self, record):
             for msg in modality_messages:
                 reader = self._modality_readers[msg.type]
                 # modality_encodings.append((msg.type, reader.read(msg.content)))
-                modality_encodings.append(reader.read(msg.content))
+                modality_encodings.append(reader.read(msg.content).contiguous())
         except (OSError, RuntimeError, KeyError):
             return None
diff --git a/turbo_alignment/modeling/multimodal/lm/projection.py b/turbo_alignment/modeling/multimodal/lm/projection.py
index 653b285..33193b7 100644
--- a/turbo_alignment/modeling/multimodal/lm/projection.py
+++ b/turbo_alignment/modeling/multimodal/lm/projection.py
@@ -83,7 +83,7 @@ def convert_inputs_to_embeds(
             for index, modality_object in enumerate(sample_modality_inputs):
                 # modality, inputs = modality_object
                 # grouped_modality_encoder_inputs[modality].append((index, inputs))
-                inputs = modality_object
+                inputs = modality_object.contiguous()
                 grouped_modality_encoder_inputs.append((index, inputs))
 
         sorted_modality_embeddings: torch.Tensor = torch.full(
@@ -126,12 +126,14 @@
     def forward(
         self,
         input_ids: torch.LongTensor,
-        modality_inputs: list[list[tuple[Modality, torch.Tensor]]],
+        modality_inputs,
         attention_mask: torch.LongTensor,
         modality_tokens_mask: torch.Tensor,
         labels: torch.LongTensor | None = None,
     ) -> ModelOutput:
-        multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs, modality_tokens_mask)
+        # NOTE(review): the collator yields a plain list of None when a batch has no
+        # modality inputs; .contiguous() would raise there — confirm only stacked tensors reach this path.
+        multimodal_lm_input_embeds = self.convert_inputs_to_embeds(input_ids, modality_inputs.contiguous(), modality_tokens_mask)
         return self.language_model(
             inputs_embeds=multimodal_lm_input_embeds, labels=labels, attention_mask=attention_mask
         )
diff --git a/turbo_alignment/pipelines/inference/multimodal.py b/turbo_alignment/pipelines/inference/multimodal.py
index addaf11..b8701ce 100755
--- a/turbo_alignment/pipelines/inference/multimodal.py
+++ b/turbo_alignment/pipelines/inference/multimodal.py
@@ -40,4 +40,8 @@ def _load_model(
             peft=True,
         )
         model.modality_adapters.load_state_dict(torch.load(inference_settings.model_settings.projections_path))
+        # safetensors refuses to serialize non-contiguous tensors, so make
+        # every parameter contiguous before the model is handed back.
+        for param in model.parameters():
+            param.data = param.data.contiguous()
         return model