Commit 4a82531: Remove unrelated changes
ArthurCamara committed Sep 30, 2024
1 parent 4869ea5 commit 4a82531
Showing 3 changed files with 16 additions and 64 deletions.
36 changes: 5 additions & 31 deletions sentence_transformers/SentenceTransformer.py
@@ -154,7 +154,6 @@ def __init__(
prompts: dict[str, str] | None = None,
default_prompt_name: str | None = None,
similarity_fn_name: str | SimilarityFunction | None = None,
- mask_prompt: bool = False,
cache_folder: str | None = None,
trust_remote_code: bool = False,
revision: str | None = None,
@@ -167,11 +166,10 @@ def __init__(
config_kwargs: dict[str, Any] | None = None,
model_card_data: SentenceTransformerModelCardData | None = None,
) -> None:
- # Note: self._load_sbert_model can also update `self.prompts`, `self.default_prompt_name` and `self.mask_prompt`
+ # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
self.default_prompt_name = default_prompt_name
self.similarity_fn_name = similarity_fn_name
- self.mask_prompt = mask_prompt
self.trust_remote_code = trust_remote_code
self.truncate_dim = truncate_dim
self.model_card_data = model_card_data or SentenceTransformerModelCardData()
@@ -356,8 +354,6 @@ def __init__(
# suspect the user is using an INSTRUCTOR model.
if model_name_or_path in ("hkunlp/instructor-base", "hkunlp/instructor-large", "hkunlp/instructor-xl"):
self.set_pooling_include_prompt(include_prompt=False)
- if self.mask_prompt:
-     self.set_pooling_mask_prompt(mask_prompt=True)
elif (
model_name_or_path
and "/" in model_name_or_path
@@ -570,7 +566,7 @@ def encode(

for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
sentences_batch = sentences_sorted[start_index : start_index + batch_size]
- features = self.tokenize(sentences_batch, prompt_length=extra_features.get("prompt_length", None))
+ features = self.tokenize(sentences_batch)
if self.device.type == "hpu":
if "input_ids" in features:
curr_tokenize_len = features["input_ids"].shape
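For context, the prompt text itself is prepended inside `encode` (standard Sentence Transformers behavior), so the public API is unaffected by the simplified `tokenize` call above. A minimal usage sketch, assuming `all-MiniLM-L6-v2` as an example checkpoint and a made-up prompt string:

```python
from sentence_transformers import SentenceTransformer

# Example checkpoint and prompt text; any Sentence Transformers model works the same way.
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    prompts={"query": "Represent this query for retrieval: "},
)

# The prompt named "query" is prepended to each input before tokenization.
embeddings = model.encode(["what is mean pooling?"], prompt_name="query")
print(embeddings.shape)  # (1, embedding_dim)
```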
@@ -1002,24 +998,6 @@ def set_pooling_include_prompt(self, include_prompt: bool) -> None:
module.include_prompt = include_prompt
break

- def set_pooling_mask_prompt(self, mask_prompt: bool) -> None:
-     """
-     Sets the `mask_prompt` attribute in the pooling layer, if there is one.
-     This triggers the use of the `embed_mask` in the pooling instead of the attention mask. This is useful for models, such as NV-Embed and LLM2Vec, that mask the user's prompt.
-
-     Args:
-         mask_prompt (bool): Whether to mask the prompt in the model.
-
-     Returns:
-         None
-     """
-     for module in self:
-         if isinstance(module, Pooling):
-             module.mask_prompt = mask_prompt
-             break
-
def get_max_seq_length(self) -> int | None:
"""
Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.
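For readers unfamiliar with the feature removed above: `mask_prompt` switched pooling from the attention mask to a prompt-excluding mask. The toy sketch below only illustrates that idea; it is not code from this repository, and the tensor shapes and the `mean_pool` helper are invented for illustration.

```python
import torch

# Toy setup: one sequence of 6 tokens whose first 2 tokens are the prompt.
token_embeddings = torch.randn(1, 6, 4)               # (batch, seq_len, hidden)
attention_mask = torch.ones(1, 6, dtype=torch.int64)  # all tokens are real

# Conceptually what `mask_prompt` enabled: pool over a mask that hides the
# prompt tokens instead of the plain attention mask.
prompt_mask = attention_mask.clone()
prompt_mask[:, :2] = 0  # hide the 2 prompt tokens

def mean_pool(embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask = mask.unsqueeze(-1).to(embeddings.dtype)
    return (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

with_prompt = mean_pool(token_embeddings, attention_mask)  # prompt included
without_prompt = mean_pool(token_embeddings, prompt_mask)  # prompt excluded
print(with_prompt.shape, without_prompt.shape)             # torch.Size([1, 4]) each
```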
@@ -1032,9 +1010,7 @@ def get_max_seq_length(self) -> int | None:

return None

- def tokenize(
-     self, texts: list[str] | list[dict] | list[tuple[str, str]], prompt_length: int | None = None
- ) -> dict[str, Tensor]:
+ def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dict[str, Tensor]:
"""
Tokenizes the texts.
@@ -1045,7 +1021,7 @@ def tokenize(
Dict[str, Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids",
"attention_mask", and "token_type_ids".
"""
- return self._first_module().tokenize(texts, mask_prompt=self.mask_prompt, prompt_length=prompt_length)
+ return self._first_module().tokenize(texts)

def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], Tensor]:
return self._first_module().get_sentence_features(*features)
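With the extra parameters gone, `SentenceTransformer.tokenize` again takes only the texts. A quick usage sketch, assuming an example checkpoint:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # example checkpoint
features = model.tokenize(["Hello world", "A second, slightly longer sentence"])
# Typical keys: input_ids and attention_mask (plus token_type_ids for BERT-style models).
print(features["input_ids"].shape, features["attention_mask"].shape)
```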
@@ -1147,7 +1123,6 @@ def save(
config["prompts"] = self.prompts
config["default_prompt_name"] = self.default_prompt_name
config["similarity_fn_name"] = self.similarity_fn_name
config["mask_prompt"] = self.mask_prompt
json.dump(config, fOut, indent=2)

# Save modules
@@ -1569,8 +1544,7 @@ def _load_sbert_model(
self.prompts = self._model_config.get("prompts", {})
if not self.default_prompt_name:
self.default_prompt_name = self._model_config.get("default_prompt_name", None)
- if not self.mask_prompt:
-     self.mask_prompt = self._model_config.get("mask_prompt", False)


# Check if a readme exists
model_card_path = load_file_path(
8 changes: 2 additions & 6 deletions sentence_transformers/models/Pooling.py
@@ -57,8 +57,7 @@ def __init__(
pooling_mode_mean_sqrt_len_tokens: bool = False,
pooling_mode_weightedmean_tokens: bool = False,
pooling_mode_lasttoken: bool = False,
- include_prompt: bool = True,
- mask_prompt: bool = False,
+ include_prompt = True,
) -> None:
super().__init__()

@@ -97,7 +96,6 @@ def __init__(
self.pooling_mode_lasttoken = pooling_mode_lasttoken

self.include_prompt = include_prompt
- self.mask_prompt = mask_prompt

pooling_mode_multiplier = sum(
[
@@ -139,9 +137,7 @@ def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
if "attention_mask" in features
else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
)
- if self.mask_prompt:
-     attention_mask = features["prompt_mask"]
- elif not self.include_prompt and "prompt_length" in features:
+ if not self.include_prompt and "prompt_length" in features:
attention_mask[:, : features["prompt_length"]] = 0

## Pooling strategy
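After this change, the only prompt handling left in `Pooling.forward` is the `include_prompt` branch kept above: when `prompt_length` is present in the features, the leading prompt positions are zeroed out of the attention mask before pooling. A toy sketch of just that masking step (batch size, sequence length, and prompt length are invented):

```python
import torch

attention_mask = torch.ones(2, 8, dtype=torch.int64)  # (batch, seq_len), all tokens real
prompt_length = 3                                      # first 3 tokens are the prompt

include_prompt = False
if not include_prompt:
    # Same idea as the surviving branch in Pooling.forward: prompt tokens no
    # longer contribute to mean / weighted-mean pooling.
    attention_mask[:, :prompt_length] = 0

print(attention_mask)
# tensor([[0, 0, 0, 1, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 1, 1]])
```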
36 changes: 9 additions & 27 deletions sentence_transformers/models/Transformer.py
@@ -116,8 +116,6 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torc
trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
if "token_type_ids" in features:
trans_features["token_type_ids"] = features["token_type_ids"]
if "prompt_mask" in features:
trans_features["attention_mask"] = features["prompt_mask"]
output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
output_tokens = output_states[0]

@@ -139,9 +137,7 @@ def get_word_embedding_dimension(self) -> int:
def tokenize(
self,
texts: list[str] | list[dict] | list[tuple[str, str]],
- padding: str | bool = True,
- mask_prompt: bool = False,
- prompt_length: int | None = None,
+ padding: str | bool = True
) -> dict[str, torch.Tensor]:
"""Tokenizes a text and maps tokens to token-ids"""
output = {}
@@ -169,29 +165,15 @@ def tokenize(
if self.do_lower_case:
to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

- tokens = self.tokenizer(
-     *to_tokenize,
-     padding=padding,
-     truncation="longest_first",
-     return_tensors="pt",
-     max_length=self.max_seq_length,
- )
+ output.update(
+     self.tokenizer(
+         *to_tokenize,
+         padding=padding,
+         truncation="longest_first",
+         return_tensors="pt",
+         max_length=self.max_seq_length,
+     )
+ )
- if mask_prompt and prompt_length is not None:
-     # We need to grab the first non-pad token in the input_ids and mask from there until prompt_length tokens
-     if self.tokenizer.padding_side == "left":
-         embed_mask = tokens["attention_mask"].clone()
-         pad_mask = tokens["input_ids"] == self.tokenizer.pad_token_id
-         first_non_pad_token = pad_mask.long().argmin(dim=1)
-         last_instruction_token = first_non_pad_token + prompt_length
-         mask_indices = torch.arange(0, embed_mask.shape[1]).long()[None, :] <= last_instruction_token[:, None]
-         embed_mask[mask_indices] = 0
-         tokens["prompt_mask"] = embed_mask
-     else:
-         # For right-sided padding, just mask from the start until prompt_length
-         tokens["prompt_mask"] = torch.ones_like(tokens["attention_mask"])
-         tokens["prompt_mask"][:, :prompt_length] = 0

- output.update(tokens)
return output

def get_config_dict(self) -> dict[str, Any]:
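The branch removed above built a `prompt_mask` for left-padded inputs by finding the first non-pad token of each row and masking from the start of the row through `prompt_length` positions past it. A self-contained toy reproduction of that arithmetic, with a made-up pad id, sequence length, and `prompt_length`:

```python
import torch

pad_id = 0
prompt_length = 2

# Two left-padded sequences: row 0 has 1 pad token, row 1 has 2 pad tokens.
input_ids = torch.tensor([
    [pad_id, 11, 12, 13, 14, 15],
    [pad_id, pad_id, 21, 22, 23, 24],
])
attention_mask = (input_ids != pad_id).long()

# Same arithmetic as the removed padding_side == "left" branch.
embed_mask = attention_mask.clone()
pad_mask = input_ids == pad_id
first_non_pad_token = pad_mask.long().argmin(dim=1)           # index of the first real token
last_instruction_token = first_non_pad_token + prompt_length
mask_indices = torch.arange(0, embed_mask.shape[1]).long()[None, :] <= last_instruction_token[:, None]
embed_mask[mask_indices] = 0  # zero everything up to and including last_instruction_token

print(embed_mask)
# tensor([[0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1]])
```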
