Skip to content

Commit

Permalink
Add support for future model2vec version based on numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Oct 17, 2024
1 parent 4fec4ca commit 09499de
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sentence_transformers/models/StaticEmbedding.py
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,10 @@ def from_distillation(
apply_zipf=apply_zipf,
use_subword=use_subword,
)
-embedding_weights = static_model.embedding.weight
+if isinstance(static_model.embedding, np.ndarray):
+    embedding_weights = torch.from_numpy(static_model.embedding)
+else:
+    embedding_weights = static_model.embedding.weight
tokenizer: Tokenizer = static_model.tokenizer

return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name)
@@ -202,7 +205,10 @@ def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding:
raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")

static_model = StaticModel.from_pretrained(model_id_or_path)
-embedding_weights = static_model.embedding.weight
+if isinstance(static_model.embedding, np.ndarray):
+    embedding_weights = torch.from_numpy(static_model.embedding)
+else:
+    embedding_weights = static_model.embedding.weight
tokenizer: Tokenizer = static_model.tokenizer

return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path)

0 comments on commit 09499de

Please sign in to comment.