
Commit

If AutoModel is wrapped with PEFT for prompt learning, then extend the attention mask
tomaarsen committed Oct 18, 2024
1 parent: 1802076 · commit: 9ddf2d5
Showing 1 changed file with 19 additions and 2 deletions.
sentence_transformers/models/Transformer.py
@@ -11,6 +11,7 @@
 import torch
 from torch import nn
 from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers.utils import is_peft_available
 
 logger = logging.getLogger(__name__)

@@ -350,15 +351,31 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
         output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
         output_tokens = output_states[0]
 
-        features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
+        # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, then it may have added virtual tokens
+        # We need to extend the attention mask to include these virtual tokens, or the pooling will fail
+        if is_peft_available():
+            from peft import PeftModelForFeatureExtraction
+
+            if (
+                isinstance(self.auto_model, PeftModelForFeatureExtraction)
+                and self.auto_model.active_peft_config.is_prompt_learning
+            ):
+                batch_size = output_tokens.size(0)
+                attention_mask = features["attention_mask"]
+                prefix_attention_mask = torch.ones(
+                    batch_size, self.auto_model.active_peft_config.num_virtual_tokens, device=attention_mask.device
+                )
+                features["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        features["token_embeddings"] = output_tokens
 
         if self.auto_model.config.output_hidden_states:
             all_layer_idx = 2
             if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                 all_layer_idx = 1
 
             hidden_states = output_states[all_layer_idx]
-            features.update({"all_layer_embeddings": hidden_states})
+            features["all_layer_embeddings"] = hidden_states
 
         return features
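
Why the mask must grow: with prompt learning, PEFT prepends num_virtual_tokens trainable embeddings to every sequence, so the model returns more token embeddings than the tokenizer's attention mask covers, and pooling would misalign. A minimal sketch of the shape arithmetic, with made-up dimensions (batch_size, seq_len, hidden_dim, and num_virtual_tokens here are arbitrary illustration values):

import torch

# Hypothetical sizes for illustration only
batch_size, seq_len, hidden_dim, num_virtual_tokens = 2, 5, 384, 8

# The model output covers virtual tokens + real tokens ...
output_tokens = torch.randn(batch_size, num_virtual_tokens + seq_len, hidden_dim)
# ... but the tokenizer's attention mask only covers the real tokens.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# The fix: prepend an all-ones mask for the virtual tokens, as in the commit
prefix_attention_mask = torch.ones(
    batch_size, num_virtual_tokens, device=attention_mask.device, dtype=attention_mask.dtype
)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

# The mask now lines up with the token embeddings for pooling
assert attention_mask.shape == output_tokens.shape[:2]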

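For context, one way the new code path can be exercised is by wrapping the inner auto_model with a PEFT prompt-tuning adapter. This sketch is not part of the commit: the model name and num_virtual_tokens=8 are arbitrary choices, and it assumes a peft release that ships PromptTuningConfig and get_peft_model:

from peft import PromptTuningConfig, TaskType, get_peft_model
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Prompt tuning is a prompt-learning method: it prepends trainable virtual
# tokens, which is what triggers the mask extension in the patched forward()
peft_config = PromptTuningConfig(task_type=TaskType.FEATURE_EXTRACTION, num_virtual_tokens=8)

# Wrap the inner Hugging Face model; model[0] is the Transformer module
# patched by this commit, and get_peft_model returns a
# PeftModelForFeatureExtraction for TaskType.FEATURE_EXTRACTION
model[0].auto_model = get_peft_model(model[0].auto_model, peft_config)

# encode() now pools correctly because forward() extends the attention mask
embeddings = model.encode(["An example sentence"])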
