[enh] Throw error if StaticEmbedding-based model is trained with incompatible loss #2990

Merged
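For context, this PR makes the incompatible losses fail fast at construction time rather than erroring mid-training. A minimal sketch of the user-facing behavior, assuming a sentence-transformers version where StaticEmbedding is available (the tokenizer checkpoint and embedding size are illustrative, not part of the diff):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.models import StaticEmbedding
from transformers import AutoTokenizer

# Build a SentenceTransformer whose only module is a StaticEmbedding.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=512)
model = SentenceTransformer(modules=[static_embedding])

# After this PR, an incompatible loss raises at construction time:
try:
    loss = CachedMultipleNegativesRankingLoss(model)
except ValueError as err:
    print(err)  # "CachedMultipleNegativesRankingLoss is not compatible with a ..."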
sentence_transformers/losses/CachedGISTEmbedLoss.py (6 additions, 1 deletion)

@@ -10,7 +10,7 @@
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 from sentence_transformers import SentenceTransformer
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 
 
 class RandContext:
@@ -139,6 +139,11 @@ def __init__(
                 trainer.train()
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedGISTEmbedLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using GISTEmbedLoss instead."
+            )
         self.model = model
         self.guide = guide
         self.temperature = temperature
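All five guards in this PR use the same pattern: a SentenceTransformer can be indexed like an nn.Sequential, so model[0] is the first module in the encoding pipeline. A small sketch of the check, reusing the model from the sketch above:

from sentence_transformers.models import StaticEmbedding

# model[0] is the first module in the pipeline: a StaticEmbedding for a
# static model, a Transformer module for a regular one.
if isinstance(model[0], StaticEmbedding):
    print("Static model detected; the cached losses refuse it.")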
sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py

@@ -10,6 +10,7 @@
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 from sentence_transformers import SentenceTransformer, util
+from sentence_transformers.models import StaticEmbedding
 
 
 class RandContext:
@@ -145,6 +146,12 @@ def __init__(
                 trainer.train()
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedMultipleNegativesRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using MultipleNegativesRankingLoss instead."
+            )
+
         self.model = model
         self.scale = scale
         self.similarity_fct = similarity_fct
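The error message points to the non-cached variant, which accepts StaticEmbedding-based models. A usage sketch of the suggested replacement, mirroring the docstring example referenced in the diff context (the two-column dataset is illustrative):

from datasets import Dataset
from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# The non-cached variant works with StaticEmbedding-based models.
loss = MultipleNegativesRankingLoss(model)
train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})
trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()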
sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py

@@ -10,6 +10,7 @@
 
 from sentence_transformers import SentenceTransformer, util
 from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext
+from sentence_transformers.models import StaticEmbedding
 
 
 def _backward_hook(
@@ -114,6 +115,12 @@ def __init__(
             - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedMultipleNegativesSymmetricRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using MultipleNegativesSymmetricRankingLoss instead."
+            )
+
         self.model = model
         self.scale = scale
         self.similarity_fct = similarity_fct
sentence_transformers/losses/DenoisingAutoEncoderLoss.py (7 additions, 0 deletions)

@@ -7,6 +7,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
 
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import StaticEmbedding
 
 logger = logging.getLogger(__name__)
 
@@ -73,6 +74,12 @@ def __init__(
             )
         """
         super().__init__()
+
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "DenoisingAutoEncoderLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding."
+            )
+
         self.encoder = model  # This will be the final model used during the inference time.
         self.tokenizer_encoder = model.tokenizer
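DenoisingAutoEncoderLoss has no drop-in alternative, so its message names none: as the AutoModelForCausalLM import above suggests, the loss couples the encoder to a causal-LM decoder built from a Hugging Face checkpoint, which presupposes a Transformer-backed encoder. A brief sketch of the supported setup (model name illustrative):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import DenoisingAutoEncoderLoss

# A Transformer-based encoder is required; a StaticEmbedding model is not.
model = SentenceTransformer("bert-base-uncased")
loss = DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)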
sentence_transformers/losses/GISTEmbedLoss.py (7 additions, 1 deletion)

@@ -5,7 +5,7 @@
 import torch
 from torch import Tensor, nn
 
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 from sentence_transformers.SentenceTransformer import SentenceTransformer
 
 
@@ -91,6 +91,12 @@ def __init__(
         if self.must_retokenize:
             self.tokenizer = self.model.tokenizer
 
+            if isinstance(self.model[0], StaticEmbedding):
+                raise ValueError(
+                    "If we must retokenize because the guide model has a different tokenizer, "
+                    "then the Sentence Transformer model must not be based on a StaticEmbedding."
+                )
+
     def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
         return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
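Unlike the other four losses, GISTEmbedLoss rejects StaticEmbedding models only conditionally: the guard sits inside the must_retokenize branch, which is taken when the guide model's tokenizer does not match the trained model's, since the loss must then decode and retokenize inputs for the guide. A sketch of the two cases (model names illustrative):

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import GISTEmbedLoss

guide = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Same tokenizer as the guide: no retokenization is needed, so even a
# StaticEmbedding-based model would be accepted here.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
loss = GISTEmbedLoss(model, guide)

# A StaticEmbedding-based model whose tokenizer differs from the guide's
# would instead hit the new ValueError above.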