From 3204a26280dcfad084ca2364ffda1e691694d4f5 Mon Sep 17 00:00:00 2001
From: Tom Aarsen
Date: Wed, 30 Oct 2024 12:47:11 +0100
Subject: [PATCH] Use a simple int as a prompt length instead of a tensor

---
 sentence_transformers/data_collator.py  |   4 +-
 sentence_transformers/models/Pooling.py | 500 ++++++++++++------------
 2 files changed, 248 insertions(+), 256 deletions(-)

diff --git a/sentence_transformers/data_collator.py b/sentence_transformers/data_collator.py
index d9ac933c9..43474b754 100644
--- a/sentence_transformers/data_collator.py
+++ b/sentence_transformers/data_collator.py
@@ -58,9 +58,7 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
             for key, value in tokenized.items():
                 batch[f"{column_name}_{key}"] = value
             if prompt_len is not None:
-                batch[f"{column_name}_prompt_length"] = torch.tensor(
-                    [prompt_len] * len(values), device=batch[f"{column_name}_input_ids"].device, dtype=torch.long
-                )
+                batch[f"{column_name}_prompt_length"] = prompt_len
 
         return batch
 
     def add_prompts_to_column(
diff --git a/sentence_transformers/models/Pooling.py b/sentence_transformers/models/Pooling.py
index e4a9e4dfc..ebfd54ba1 100644
--- a/sentence_transformers/models/Pooling.py
+++ b/sentence_transformers/models/Pooling.py
@@ -1,253 +1,247 @@
-from __future__ import annotations
-
-import json
-import os
-from typing import Any
-
-import torch
-from torch import Tensor, nn
-
-
-class Pooling(nn.Module):
-    """
-    Performs pooling (max or mean) on the token embeddings.
-
-    Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding. This layer also allows
-    to use the CLS token if it is returned by the underlying word embedding model. You can concatenate multiple poolings
-    together.
-
-    Args:
-        word_embedding_dimension: Dimensions for the word embeddings
-        pooling_mode: Either "cls", "lasttoken", "max", "mean",
-            "mean_sqrt_len_tokens", or "weightedmean". If set,
-            overwrites the other pooling_mode_* settings
-        pooling_mode_cls_token: Use the first token (CLS token) as text
-            representations
-        pooling_mode_max_tokens: Use max in each dimension over all
-            tokens.
-        pooling_mode_mean_tokens: Perform mean-pooling
-        pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but
-            divide by sqrt(input_length).
-        pooling_mode_weightedmean_tokens: Perform (position) weighted
-            mean pooling. See `SGPT: GPT Sentence Embeddings for
-            Semantic Search <https://arxiv.org/abs/2202.08904>`_.
-        pooling_mode_lasttoken: Perform last token pooling. See `SGPT:
-            GPT Sentence Embeddings for Semantic Search
-            <https://arxiv.org/abs/2202.08904>`_ and `Text and Code
-            Embeddings by Contrastive Pre-Training
-            <https://arxiv.org/abs/2201.10005>`_.
- """ - - POOLING_MODES = ( - "cls", - "lasttoken", - "max", - "mean", - "mean_sqrt_len_tokens", - "weightedmean", - ) - - def __init__( - self, - word_embedding_dimension: int, - pooling_mode: str = None, - pooling_mode_cls_token: bool = False, - pooling_mode_max_tokens: bool = False, - pooling_mode_mean_tokens: bool = True, - pooling_mode_mean_sqrt_len_tokens: bool = False, - pooling_mode_weightedmean_tokens: bool = False, - pooling_mode_lasttoken: bool = False, - include_prompt: bool = True, - ) -> None: - super().__init__() - - self.config_keys = [ - "word_embedding_dimension", - "pooling_mode_cls_token", - "pooling_mode_mean_tokens", - "pooling_mode_max_tokens", - "pooling_mode_mean_sqrt_len_tokens", - "pooling_mode_weightedmean_tokens", - "pooling_mode_lasttoken", - "include_prompt", - ] - - if pooling_mode is not None: # Set pooling mode by string - pooling_mode = pooling_mode.lower() - - if pooling_mode not in self.POOLING_MODES: - raise ValueError( - f"Set invalid pooling mode: {pooling_mode}. Valid pooling modes are: {self.POOLING_MODES}." - ) - - pooling_mode_cls_token = pooling_mode == "cls" - pooling_mode_max_tokens = pooling_mode == "max" - pooling_mode_mean_tokens = pooling_mode == "mean" - pooling_mode_mean_sqrt_len_tokens = pooling_mode == "mean_sqrt_len_tokens" - pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean" - pooling_mode_lasttoken = pooling_mode == "lasttoken" - - self.word_embedding_dimension = word_embedding_dimension - self.pooling_mode_cls_token = pooling_mode_cls_token - self.pooling_mode_mean_tokens = pooling_mode_mean_tokens - self.pooling_mode_max_tokens = pooling_mode_max_tokens - self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens - self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens - self.pooling_mode_lasttoken = pooling_mode_lasttoken - - self.include_prompt = include_prompt - - pooling_mode_multiplier = sum( - [ - pooling_mode_cls_token, - pooling_mode_max_tokens, - pooling_mode_mean_tokens, - pooling_mode_mean_sqrt_len_tokens, - pooling_mode_weightedmean_tokens, - pooling_mode_lasttoken, - ] - ) - self.pooling_output_dimension = pooling_mode_multiplier * word_embedding_dimension - - def __repr__(self) -> str: - return f"Pooling({self.get_config_dict()})" - - def get_pooling_mode_str(self) -> str: - """Returns the pooling mode as string""" - modes = [] - if self.pooling_mode_cls_token: - modes.append("cls") - if self.pooling_mode_mean_tokens: - modes.append("mean") - if self.pooling_mode_max_tokens: - modes.append("max") - if self.pooling_mode_mean_sqrt_len_tokens: - modes.append("mean_sqrt_len_tokens") - if self.pooling_mode_weightedmean_tokens: - modes.append("weightedmean") - if self.pooling_mode_lasttoken: - modes.append("lasttoken") - - return "+".join(modes) - - def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]: - token_embeddings = features["token_embeddings"] - attention_mask = ( - features["attention_mask"] - if "attention_mask" in features - else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64) - ) - if not self.include_prompt and "prompt_length" in features: - if isinstance(features["prompt_length"], Tensor): - seq_len = attention_mask.shape[1] - lens = features["prompt_length"] - mask_indices = torch.arange(0, seq_len).to(attention_mask.device)[None, :] <= lens[:, None] - attention_mask[mask_indices] = 0 - else: - attention_mask[:, : features["prompt_length"]] = 0 - - ## Pooling strategy - output_vectors = [] - if 
self.pooling_mode_cls_token: - cls_token = features.get("cls_token_embeddings", token_embeddings[:, 0]) # Take first token by default - output_vectors.append(cls_token) - if self.pooling_mode_max_tokens: - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype) - ) - token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value - max_over_time = torch.max(token_embeddings, 1)[0] - output_vectors.append(max_over_time) - if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens: - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype) - ) - sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) - - # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present - if "token_weights_sum" in features: - sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size()) - else: - sum_mask = input_mask_expanded.sum(1) - - sum_mask = torch.clamp(sum_mask, min=1e-9) - - if self.pooling_mode_mean_tokens: - output_vectors.append(sum_embeddings / sum_mask) - if self.pooling_mode_mean_sqrt_len_tokens: - output_vectors.append(sum_embeddings / torch.sqrt(sum_mask)) - if self.pooling_mode_weightedmean_tokens: - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype) - ) - # token_embeddings shape: bs, seq, hidden_dim - weights = ( - torch.arange(start=1, end=token_embeddings.shape[1] + 1) - .unsqueeze(0) - .unsqueeze(-1) - .expand(token_embeddings.size()) - .to(token_embeddings.dtype) - .to(token_embeddings.device) - ) - assert weights.shape == token_embeddings.shape == input_mask_expanded.shape - input_mask_expanded = input_mask_expanded * weights - - sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) - - # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present - if "token_weights_sum" in features: - sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size()) - else: - sum_mask = input_mask_expanded.sum(1) - - sum_mask = torch.clamp(sum_mask, min=1e-9) - output_vectors.append(sum_embeddings / sum_mask) - if self.pooling_mode_lasttoken: - bs, seq_len, hidden_dim = token_embeddings.shape - # attention_mask shape: (bs, seq_len) - # Get shape [bs] indices of the last token (i.e. 
the last token for each batch item) - # Use flip and max() to get the last index of 1 in the attention mask - - if torch.jit.is_tracing(): - # Avoid tracing the argmax with int64 input that can not be handled by ONNX Runtime: https://github.com/microsoft/onnxruntime/issues/10068 - attention_mask = attention_mask.to(torch.int32) - - values, indices = attention_mask.flip(1).max(1) - indices = torch.where(values == 0, seq_len - 1, indices) - gather_indices = seq_len - indices - 1 - - # Turn indices from shape [bs] --> [bs, 1, hidden_dim] - gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim) - gather_indices = gather_indices.unsqueeze(1) - assert gather_indices.shape == (bs, 1, hidden_dim) - - # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim) - # Actually no need for the attention mask as we gather the last token where attn_mask = 1 - # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we - # use the attention mask to ignore them again - input_mask_expanded = ( - attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype) - ) - embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1) - output_vectors.append(embedding) - - output_vector = torch.cat(output_vectors, 1) - features["sentence_embedding"] = output_vector - return features - - def get_sentence_embedding_dimension(self) -> int: - return self.pooling_output_dimension - - def get_config_dict(self) -> dict[str, Any]: - return {key: self.__dict__[key] for key in self.config_keys} - - def save(self, output_path) -> None: - with open(os.path.join(output_path, "config.json"), "w") as fOut: - json.dump(self.get_config_dict(), fOut, indent=2) - - @staticmethod - def load(input_path) -> Pooling: - with open(os.path.join(input_path, "config.json")) as fIn: - config = json.load(fIn) - - return Pooling(**config) +from __future__ import annotations + +import json +import os +from typing import Any + +import torch +from torch import Tensor, nn + + +class Pooling(nn.Module): + """ + Performs pooling (max or mean) on the token embeddings. + + Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding. This layer also allows + to use the CLS token if it is returned by the underlying word embedding model. You can concatenate multiple poolings + together. + + Args: + word_embedding_dimension: Dimensions for the word embeddings + pooling_mode: Either "cls", "lasttoken", "max", "mean", + "mean_sqrt_len_tokens", or "weightedmean". If set, + overwrites the other pooling_mode_* settings + pooling_mode_cls_token: Use the first token (CLS token) as text + representations + pooling_mode_max_tokens: Use max in each dimension over all + tokens. + pooling_mode_mean_tokens: Perform mean-pooling + pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but + divide by sqrt(input_length). + pooling_mode_weightedmean_tokens: Perform (position) weighted + mean pooling. See `SGPT: GPT Sentence Embeddings for + Semantic Search `_. + pooling_mode_lasttoken: Perform last token pooling. See `SGPT: + GPT Sentence Embeddings for Semantic Search + `_ and `Text and Code + Embeddings by Contrastive Pre-Training + `_. 
+ """ + + POOLING_MODES = ( + "cls", + "lasttoken", + "max", + "mean", + "mean_sqrt_len_tokens", + "weightedmean", + ) + + def __init__( + self, + word_embedding_dimension: int, + pooling_mode: str = None, + pooling_mode_cls_token: bool = False, + pooling_mode_max_tokens: bool = False, + pooling_mode_mean_tokens: bool = True, + pooling_mode_mean_sqrt_len_tokens: bool = False, + pooling_mode_weightedmean_tokens: bool = False, + pooling_mode_lasttoken: bool = False, + include_prompt: bool = True, + ) -> None: + super().__init__() + + self.config_keys = [ + "word_embedding_dimension", + "pooling_mode_cls_token", + "pooling_mode_mean_tokens", + "pooling_mode_max_tokens", + "pooling_mode_mean_sqrt_len_tokens", + "pooling_mode_weightedmean_tokens", + "pooling_mode_lasttoken", + "include_prompt", + ] + + if pooling_mode is not None: # Set pooling mode by string + pooling_mode = pooling_mode.lower() + + if pooling_mode not in self.POOLING_MODES: + raise ValueError( + f"Set invalid pooling mode: {pooling_mode}. Valid pooling modes are: {self.POOLING_MODES}." + ) + + pooling_mode_cls_token = pooling_mode == "cls" + pooling_mode_max_tokens = pooling_mode == "max" + pooling_mode_mean_tokens = pooling_mode == "mean" + pooling_mode_mean_sqrt_len_tokens = pooling_mode == "mean_sqrt_len_tokens" + pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean" + pooling_mode_lasttoken = pooling_mode == "lasttoken" + + self.word_embedding_dimension = word_embedding_dimension + self.pooling_mode_cls_token = pooling_mode_cls_token + self.pooling_mode_mean_tokens = pooling_mode_mean_tokens + self.pooling_mode_max_tokens = pooling_mode_max_tokens + self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens + self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens + self.pooling_mode_lasttoken = pooling_mode_lasttoken + + self.include_prompt = include_prompt + + pooling_mode_multiplier = sum( + [ + pooling_mode_cls_token, + pooling_mode_max_tokens, + pooling_mode_mean_tokens, + pooling_mode_mean_sqrt_len_tokens, + pooling_mode_weightedmean_tokens, + pooling_mode_lasttoken, + ] + ) + self.pooling_output_dimension = pooling_mode_multiplier * word_embedding_dimension + + def __repr__(self) -> str: + return f"Pooling({self.get_config_dict()})" + + def get_pooling_mode_str(self) -> str: + """Returns the pooling mode as string""" + modes = [] + if self.pooling_mode_cls_token: + modes.append("cls") + if self.pooling_mode_mean_tokens: + modes.append("mean") + if self.pooling_mode_max_tokens: + modes.append("max") + if self.pooling_mode_mean_sqrt_len_tokens: + modes.append("mean_sqrt_len_tokens") + if self.pooling_mode_weightedmean_tokens: + modes.append("weightedmean") + if self.pooling_mode_lasttoken: + modes.append("lasttoken") + + return "+".join(modes) + + def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]: + token_embeddings = features["token_embeddings"] + attention_mask = ( + features["attention_mask"] + if "attention_mask" in features + else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64) + ) + if not self.include_prompt and "prompt_length" in features: + attention_mask[:, : features["prompt_length"]] = 0 + + ## Pooling strategy + output_vectors = [] + if self.pooling_mode_cls_token: + cls_token = features.get("cls_token_embeddings", token_embeddings[:, 0]) # Take first token by default + output_vectors.append(cls_token) + if self.pooling_mode_max_tokens: + input_mask_expanded = ( + 
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
+            )
+            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
+            max_over_time = torch.max(token_embeddings, 1)[0]
+            output_vectors.append(max_over_time)
+        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
+            input_mask_expanded = (
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
+            )
+            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+
+            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
+            if "token_weights_sum" in features:
+                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
+            else:
+                sum_mask = input_mask_expanded.sum(1)
+
+            sum_mask = torch.clamp(sum_mask, min=1e-9)
+
+            if self.pooling_mode_mean_tokens:
+                output_vectors.append(sum_embeddings / sum_mask)
+            if self.pooling_mode_mean_sqrt_len_tokens:
+                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
+        if self.pooling_mode_weightedmean_tokens:
+            input_mask_expanded = (
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
+            )
+            # token_embeddings shape: bs, seq, hidden_dim
+            weights = (
+                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
+                .unsqueeze(0)
+                .unsqueeze(-1)
+                .expand(token_embeddings.size())
+                .to(token_embeddings.dtype)
+                .to(token_embeddings.device)
+            )
+            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
+            input_mask_expanded = input_mask_expanded * weights
+
+            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+
+            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
+            if "token_weights_sum" in features:
+                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
+            else:
+                sum_mask = input_mask_expanded.sum(1)
+
+            sum_mask = torch.clamp(sum_mask, min=1e-9)
+            output_vectors.append(sum_embeddings / sum_mask)
+        if self.pooling_mode_lasttoken:
+            bs, seq_len, hidden_dim = token_embeddings.shape
+            # attention_mask shape: (bs, seq_len)
+            # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
+            # Use flip and max() to get the last index of 1 in the attention mask
+
+            if torch.jit.is_tracing():
+                # Avoid tracing the argmax with int64 input that can not be handled by ONNX Runtime: https://github.com/microsoft/onnxruntime/issues/10068
+                attention_mask = attention_mask.to(torch.int32)
+
+            values, indices = attention_mask.flip(1).max(1)
+            indices = torch.where(values == 0, seq_len - 1, indices)
+            gather_indices = seq_len - indices - 1
+
+            # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
+            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
+            gather_indices = gather_indices.unsqueeze(1)
+            assert gather_indices.shape == (bs, 1, hidden_dim)
+
+            # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
+            # Actually no need for the attention mask as we gather the last token where attn_mask = 1
+            # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we
+            # use the attention mask to ignore them again
+            input_mask_expanded = (
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
+            )
+            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
+            output_vectors.append(embedding)
+
+        output_vector = torch.cat(output_vectors, 1)
+        features["sentence_embedding"] = output_vector
+        return features
+
+    def get_sentence_embedding_dimension(self) -> int:
+        return self.pooling_output_dimension
+
+    def get_config_dict(self) -> dict[str, Any]:
+        return {key: self.__dict__[key] for key in self.config_keys}
+
+    def save(self, output_path) -> None:
+        with open(os.path.join(output_path, "config.json"), "w") as fOut:
+            json.dump(self.get_config_dict(), fOut, indent=2)
+
+    @staticmethod
+    def load(input_path) -> Pooling:
+        with open(os.path.join(input_path, "config.json")) as fIn:
+            config = json.load(fIn)
+
+        return Pooling(**config)
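
For illustration only (not part of the patch): a minimal sketch of the new behaviour, assuming the patched Pooling layer above. The data collator now stores the prompt length as a plain int shared by the whole batch, and Pooling.forward simply zeroes the first prompt_length attention-mask positions for every sequence. The toy batch size, sequence length, and embedding dimension below are hypothetical.

    import torch
    from sentence_transformers.models import Pooling

    # Mean pooling that excludes the instruction/prompt tokens at the start of each sequence
    pooling = Pooling(word_embedding_dimension=4, pooling_mode="mean", include_prompt=False)
    features = {
        "token_embeddings": torch.randn(2, 6, 4),               # (batch, seq_len, hidden)
        "attention_mask": torch.ones(2, 6, dtype=torch.int64),
        "prompt_length": 2,  # a plain int for the whole batch, no longer a per-row tensor
    }
    out = pooling(features)
    print(out["sentence_embedding"].shape)  # torch.Size([2, 4]); the first 2 tokens are masked out of the mean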