diff --git a/InstructorEmbedding/instructor.py b/InstructorEmbedding/instructor.py
index 72b3df2..7cfdaba 100644
--- a/InstructorEmbedding/instructor.py
+++ b/InstructorEmbedding/instructor.py
@@ -3,12 +3,13 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Union
+from typing import Literal, Union

 import numpy as np
 import torch
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.models import Transformer
+from sentence_transformers.util import load_file_path, load_dir_path
 from torch import Tensor, nn
 from tqdm.autonotebook import trange
 from transformers import AutoConfig, AutoTokenizer
@@ -278,6 +279,7 @@ def __init__(
         max_seq_length=None,
         model_args=None,
         cache_dir=None,
+        backend: Literal["torch", "onnx", "openvino"] = "torch",
         tokenizer_args=None,
         do_lower_case: bool = False,
         tokenizer_name_or_path: Union[str, None] = None,
@@ -307,7 +309,7 @@ def __init__(
         )
         if load_model:
-            self._load_model(self.model_name_or_path, config, cache_dir, **model_args)
+            self._load_model(self.model_name_or_path, config, cache_dir, **model_args, backend=backend)

         self.tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name_or_path
             if tokenizer_name_or_path is not None
@@ -517,7 +519,9 @@ def smart_batching_collate(self, batch):

         return batched_input_features, labels

-    def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
+    def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None,
+                          trust_remote_code=False, local_files_only=False, model_kwargs=None,
+                          tokenizer_kwargs=None, config_kwargs=None):
         """
         Loads a full sentence-transformers model
         """
@@ -527,15 +531,15 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
         if os.path.isdir(model_path):
             model_path = str(model_path)
         else:
-            # If model_path is a Hugging Face repository ID, download the model
-            download_kwargs = {
-                "repo_id": model_path,
-                "revision": revision,
-                "library_name": "InstructorEmbedding",
-                "token": token,
-                "cache_dir": cache_folder,
-                "tqdm_class": disabled_tqdm,
-            }
+            model_path = load_dir_path(
+                model_path,
+                directory="*",
+                token=token,
+                cache_folder=cache_folder,
+                revision=revision,
+                local_files_only=local_files_only,
+            )
+            model_path = model_path[:model_path.index("*")]  # strip the "*" that load_dir_path joined onto the snapshot path

         # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
         config_sentence_transformers_json_path = os.path.join(
@@ -562,17 +566,19 @@
             modules_config = json.load(config_file)

         modules = OrderedDict()
+        module_kwargs = OrderedDict()
         for module_config in modules_config:
             if module_config["idx"] == 0:
                 module_class = INSTRUCTORTransformer
             elif module_config["idx"] == 1:
                 module_class = INSTRUCTORPooling
             else:
                 module_class = import_from_string(module_config["type"])
             module = module_class.load(os.path.join(model_path, module_config["path"]))
             modules[module_config["name"]] = module
+            module_kwargs[module_config["name"]] = module_config.get("kwargs", [])

-        return modules
+        return modules, module_kwargs

     def encode(
         self,
diff --git a/setup.py b/setup.py
index 6f17fb0..ac4dd1b 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name='InstructorEmbedding',
     packages=['InstructorEmbedding'],
-    version='1.0.2',
+    version='1.0.3',
     license='Apache License 2.0',
     description='Text embedding tool',
     long_description=readme,
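For reference, a minimal sketch of what the new `backend` keyword is meant to enable. It assumes a sentence-transformers release (v3.2 or later) whose `Transformer._load_model` accepts a `backend` argument, and `hkunlp/instructor-large` is used purely as an example checkpoint:

    from InstructorEmbedding.instructor import INSTRUCTORTransformer

    # `backend` is forwarded to sentence-transformers' Transformer._load_model,
    # which dispatches model loading to torch (the default), ONNX Runtime
    # ("onnx"), or OpenVINO ("openvino").
    embedding_module = INSTRUCTORTransformer(
        "hkunlp/instructor-large",
        backend="onnx",
    )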