Skip to content

Commit

Permalink
Remove mandatory dependency of tokenizers
Browse files (browse the repository at this point in the history)
  • Loading branch information
tomaarsen committed Oct 18, 2024
1 parent 1802076 commit 245b884
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions sentence_transformers/models/StaticEmbedding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,12 +8,14 @@
import torch
from safetensors.torch import load_file as load_safetensors_file
from safetensors.torch import save_file as save_safetensors_file
from tokenizers import Tokenizer
from torch import nn
from transformers import PreTrainedTokenizerFast
from transformers import PreTrainedTokenizerFast, is_tokenizers_available, requires_backends

from sentence_transformers.util import get_device_name

if is_tokenizers_available():
from tokenizers import Tokenizer


class StaticEmbedding(nn.Module):
def __init__(
Expand Down Expand Up @@ -60,6 +62,7 @@ def __init__(
ValueError: If neither `embedding_weights` nor `embedding_dim` is provided.
"""
super().__init__()
requires_backends(self, "tokenizers")

if isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = tokenizer._tokenizer
Expand Down Expand Up @@ -118,6 +121,7 @@ def save(self, save_dir: str, safe_serialization: bool = True, **kwargs) -> None
self.tokenizer.save(str(Path(save_dir) / "tokenizer.json"))

def load(load_dir: str, **kwargs) -> StaticEmbedding:
requires_backends(StaticEmbedding, "tokenizers")
tokenizer = Tokenizer.from_file(str(Path(load_dir) / "tokenizer.json"))
if os.path.exists(os.path.join(load_dir, "model.safetensors")):
weights = load_safetensors_file(os.path.join(load_dir, "model.safetensors"))
Expand Down Expand Up @@ -157,7 +161,7 @@ def from_distillation(
Raises:
ImportError: If the `model2vec` package is not installed.
"""

requires_backends(cls, "tokenizers")
try:
from model2vec import distill
except ImportError:
Expand Down Expand Up @@ -193,7 +197,7 @@ def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding:
Raises:
ImportError: If the `model2vec` package is not installed.
"""

requires_backends(cls, "tokenizers")
try:
from model2vec import StaticModel
except ImportError:
Expand Down

0 comments on commit 245b884

Please sign in to comment.