diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 1a8cb2efb..4fd069f7d 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -1718,10 +1718,10 @@ def _load_sbert_model( # Try to initialize the module with a lot of kwargs, but only if the module supports them # Otherwise we fall back to the load method - # try: - module = module_class(model_name_or_path, cache_dir=cache_folder, backend=self.backend, **kwargs) - # except TypeError: - # module = module_class.load(model_name_or_path) + try: + module = module_class(model_name_or_path, cache_dir=cache_folder, backend=self.backend, **kwargs) + except TypeError: + module = module_class.load(model_name_or_path) else: # Normalize does not require any files to be loaded if module_class == Normalize: diff --git a/sentence_transformers/backend.py b/sentence_transformers/backend.py index eef76352e..355f40d83 100644 --- a/sentence_transformers/backend.py +++ b/sentence_transformers/backend.py @@ -78,7 +78,9 @@ def export_optimized_onnx_model( or not isinstance(model[0], Transformer) or not isinstance(model[0].auto_model, ORTModelForFeatureExtraction) ): - raise ValueError('The model must be a SentenceTransformer model loaded with `backend="onnx"`.') + raise ValueError( + 'The model must be a Transformer-based SentenceTransformer model loaded with `backend="onnx"`.' + ) ort_model: ORTModelForFeatureExtraction = model[0].auto_model optimizer = ORTOptimizer.from_pretrained(ort_model) @@ -158,7 +160,9 @@ def export_dynamic_quantized_onnx_model( or not isinstance(model[0], Transformer) or not isinstance(model[0].auto_model, ORTModelForFeatureExtraction) ): - raise ValueError('The model must be a SentenceTransformer model loaded with `backend="onnx"`.') + raise ValueError( + 'The model must be a Transformer-based SentenceTransformer model loaded with `backend="onnx"`.' + ) ort_model: ORTModelForFeatureExtraction = model[0].auto_model quantizer = ORTQuantizer.from_pretrained(ort_model)