diff --git a/sentence_transformers/models/StaticEmbedding.py b/sentence_transformers/models/StaticEmbedding.py index de69285b2..dcb66e800 100644 --- a/sentence_transformers/models/StaticEmbedding.py +++ b/sentence_transformers/models/StaticEmbedding.py @@ -8,12 +8,14 @@ import torch from safetensors.torch import load_file as load_safetensors_file from safetensors.torch import save_file as save_safetensors_file -from tokenizers import Tokenizer from torch import nn -from transformers import PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerFast, is_tokenizers_available, requires_backends from sentence_transformers.util import get_device_name +if is_tokenizers_available(): + from tokenizers import Tokenizer + class StaticEmbedding(nn.Module): def __init__( @@ -60,6 +62,7 @@ def __init__( ValueError: If neither `embedding_weights` nor `embedding_dim` is provided. """ super().__init__() + requires_backends(self, "tokenizers") if isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = tokenizer._tokenizer @@ -118,6 +121,7 @@ def save(self, save_dir: str, safe_serialization: bool = True, **kwargs) -> None self.tokenizer.save(str(Path(save_dir) / "tokenizer.json")) def load(load_dir: str, **kwargs) -> StaticEmbedding: + requires_backends(StaticEmbedding, "tokenizers") tokenizer = Tokenizer.from_file(str(Path(load_dir) / "tokenizer.json")) if os.path.exists(os.path.join(load_dir, "model.safetensors")): weights = load_safetensors_file(os.path.join(load_dir, "model.safetensors")) @@ -157,7 +161,7 @@ def from_distillation( Raises: ImportError: If the `model2vec` package is not installed. """ - + requires_backends(cls, "tokenizers") try: from model2vec import distill except ImportError: @@ -193,7 +197,7 @@ def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding: Raises: ImportError: If the `model2vec` package is not installed. """ - + requires_backends(cls, "tokenizers") try: from model2vec import StaticModel except ImportError: