diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py
index 5a99fa419..aa83c59e8 100644
--- a/sentence_transformers/losses/CachedGISTEmbedLoss.py
+++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -10,7 +10,7 @@
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 from sentence_transformers import SentenceTransformer
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 
 
 class RandContext:
@@ -139,6 +139,11 @@ def __init__(
             trainer.train()
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedGISTEmbedLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using GISTEmbedLoss instead."
+            )
         self.model = model
         self.guide = guide
         self.temperature = temperature
diff --git a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
index c1e7d67c1..9c787fe8b 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -10,6 +10,7 @@
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 from sentence_transformers import SentenceTransformer, util
+from sentence_transformers.models import StaticEmbedding
 
 
 class RandContext:
@@ -145,6 +146,12 @@ def __init__(
             trainer.train()
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedMultipleNegativesRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using MultipleNegativesRankingLoss instead."
+            )
+
         self.model = model
         self.scale = scale
         self.similarity_fct = similarity_fct
diff --git a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
index 83fe1e06f..ac82d133f 100644
--- a/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
+++ b/sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -10,6 +10,7 @@
 
 from sentence_transformers import SentenceTransformer, util
 from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext
+from sentence_transformers.models import StaticEmbedding
 
 
 def _backward_hook(
@@ -114,6 +115,12 @@
         - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedMultipleNegativesSymmetricRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using MultipleNegativesSymmetricRankingLoss instead."
+            )
+
         self.model = model
         self.scale = scale
         self.similarity_fct = similarity_fct
diff --git a/sentence_transformers/losses/DenoisingAutoEncoderLoss.py b/sentence_transformers/losses/DenoisingAutoEncoderLoss.py
index bb1cf8bef..8f38342d7 100644
--- a/sentence_transformers/losses/DenoisingAutoEncoderLoss.py
+++ b/sentence_transformers/losses/DenoisingAutoEncoderLoss.py
@@ -7,6 +7,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
 
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import StaticEmbedding
 
 logger = logging.getLogger(__name__)
 
@@ -73,6 +74,12 @@
             )
         """
         super().__init__()
+
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "DenoisingAutoEncoderLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding."
+            )
+
         self.encoder = model  # This will be the final model used during the inference time.
         self.tokenizer_encoder = model.tokenizer
diff --git a/sentence_transformers/losses/GISTEmbedLoss.py b/sentence_transformers/losses/GISTEmbedLoss.py
index f1bb833bd..51958da5e 100644
--- a/sentence_transformers/losses/GISTEmbedLoss.py
+++ b/sentence_transformers/losses/GISTEmbedLoss.py
@@ -5,7 +5,7 @@
 import torch
 from torch import Tensor, nn
 
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 from sentence_transformers.SentenceTransformer import SentenceTransformer
 
 
@@ -91,6 +91,12 @@
         if self.must_retokenize:
             self.tokenizer = self.model.tokenizer
 
+            if isinstance(self.model[0], StaticEmbedding):
+                raise ValueError(
+                    "If we must retokenize because the guide model has a different tokenizer, "
+                    "then the Sentence Transformer model must not be based on a StaticEmbedding."
+                )
+
     def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
         return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
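
A minimal sketch of how these guards surface to users, for context. The tokenizer choice and embedding dimension below are illustrative placeholders, not part of this patch:

```python
from tokenizers import Tokenizer

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import (
    CachedMultipleNegativesRankingLoss,
    MultipleNegativesRankingLoss,
)
from sentence_transformers.models import StaticEmbedding

# Build a StaticEmbedding-based model; any tokenizer and dimension work for this sketch.
static = StaticEmbedding(Tokenizer.from_pretrained("bert-base-uncased"), embedding_dim=256)
model = SentenceTransformer(modules=[static])

try:
    # The new guard raises at construction time because model[0] is a StaticEmbedding.
    loss = CachedMultipleNegativesRankingLoss(model)
except ValueError:
    # Fall back to the non-cached equivalent named in the error message.
    loss = MultipleNegativesRankingLoss(model)
```

Because each guard checks `model[0]` in `__init__`, the incompatibility is reported when the loss is constructed rather than failing mid-training.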