Fix handling of attention_mask in encoders (#14)
fcogidi authored Sep 9, 2024
1 parent: 039b5d0 · commit: 1ad7199
Showing 3 changed files with 11 additions and 8 deletions.
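Why the fix matters: the old pattern `inputs.get("attention_mask") or inputs.get(...)` evaluates the truthiness of whatever the first lookup returns. When the key is present and maps to a tensor with more than one element, `Tensor.__bool__` raises a RuntimeError, so the fallback chain crashed precisely when a mask was supplied. The new pattern routes the fallback through `dict.get`'s default argument, which never inspects the value. A minimal sketch of the failure and the fix; the literal key "text_attention_mask" is a hypothetical stand-in for whatever `Modalities.TEXT.attention_mask` resolves to:

    import torch

    inputs = {"attention_mask": torch.ones(2, 4, dtype=torch.long)}

    # Old pattern: `or` forces Tensor.__bool__, which is ambiguous for
    # tensors with more than one element and raises RuntimeError.
    try:
        mask = inputs.get("attention_mask") or inputs.get("text_attention_mask")
    except RuntimeError as err:
        print(err)  # Boolean value of Tensor with more than one element is ambiguous

    # New pattern: the fallback lives in dict.get's default argument, so
    # no truthiness check is ever performed on the tensor itself.
    mask = inputs.get("attention_mask", inputs.get("text_attention_mask", None))
    print(mask.shape)  # torch.Size([2, 4])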
mmlearn/modules/encoders/clip_encoders.py (9 changes: 5 additions & 4 deletions)
@@ -327,8 +327,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor]
             The text embeddings. Will be a tuple with a single element.
         """
         input_ids = inputs[Modalities.TEXT]
-        attention_mask = inputs.get("attention_mask") or inputs.get(
-            Modalities.TEXT.attention_mask
+        attention_mask: Optional[torch.Tensor] = inputs.get(
+            "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
         )
         position_ids = inputs.get("position_ids")
@@ -568,8 +568,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """
         output = self.model(
             input_ids=inputs[Modalities.TEXT],
-            attention_mask=inputs.get("attention_mask")
-            or inputs.get(Modalities.TEXT.attention_mask),
+            attention_mask=inputs.get(
+                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
+            ),
             inputs_embeds=inputs.get("inputs_embeds"),
             output_attentions=inputs.get("output_attentions"),
             output_hidden_states=True,
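One behavioral difference between the two patterns is worth noting (again using the hypothetical key "text_attention_mask" in place of `Modalities.TEXT.attention_mask`): `dict.get(key, default)` only falls back when the key is absent, whereas `or` also fell through when the key was present but mapped to a falsy value such as `None`. A small sketch:

    import torch

    fallback_mask = torch.ones(2, 4, dtype=torch.long)
    inputs = {"attention_mask": None, "text_attention_mask": fallback_mask}

    # `or` treats an explicit None as "missing" and returns the fallback.
    old_style = inputs.get("attention_mask") or inputs.get("text_attention_mask")
    print(old_style is fallback_mask)  # True

    # dict.get only uses the default when the key is absent, so an
    # explicit None is returned as-is and never reaches the fallback key.
    new_style = inputs.get("attention_mask", inputs.get("text_attention_mask", None))
    print(new_style is None)  # True

The nested form also evaluates the inner lookup eagerly even when the outer key is present, but a dictionary lookup is cheap, so this has no practical cost here.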
mmlearn/modules/encoders/hf_text_encoders.py (5 changes: 3 additions & 2 deletions)
@@ -156,8 +156,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """
         outputs = self.model(
             input_ids=inputs[Modalities.TEXT],
-            attention_mask=inputs.get("attention_mask")
-            or inputs.get(Modalities.TEXT.attention_mask),
+            attention_mask=inputs.get(
+                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
+            ),
             position_ids=inputs.get("position_ids"),
             output_attentions=inputs.get("output_attentions"),
             return_dict=True,
projects/bioscan_clip/encoders.py (5 changes: 3 additions & 2 deletions)
@@ -140,8 +140,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
         """Run the forward pass."""
         outputs = self.model(
             input_ids=inputs[Modalities.DNA],
-            attention_mask=inputs.get("attention_mask")
-            or inputs.get(Modalities.DNA.attention_mask),
+            attention_mask=inputs.get(
+                "attention_mask", inputs.get(Modalities.DNA.attention_mask, None)
+            ),
             position_ids=inputs.get("position_ids"),
             output_attentions=inputs.get("output_attentions"),
             return_dict=True,
