Skip to content

Commit

Permalink
Updated to work with the latest version of sentence transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
SilasMarvin committed Apr 12, 2024
1 parent d92edb7 commit 989b34d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
15 changes: 14 additions & 1 deletion InstructorEmbedding/instructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from torch import Tensor, nn
from tqdm.autonotebook import trange
from transformers import AutoConfig, AutoTokenizer
from sentence_transformers.util import disabled_tqdm

This comment has been minimized.

Copy link
@pascalhuerten

pascalhuerten Jul 8, 2024

This import does not seem to be compatible with sentence-transformers==2.2.0.
Therefore, requirements.txt should be updated to reflect the new minimum dependency versions.

For reference, this is the error raised when running train.py with sentence-transformers==2.2.0:

Traceback (most recent call last):
  File "/content/instructor-embedding/train.py", line 16, in <module>
    from InstructorEmbedding import Instructor, InstructorTransformer
  File "/content/instructor-embedding/InstructorEmbedding/__init__.py", line 1, in <module>
    from .instructor import *
  File "/content/instructor-embedding/InstructorEmbedding/instructor.py", line 15, in <module>
    from sentence_transformers.util import disabled_tqdm
ImportError: cannot import name 'disabled_tqdm' from 'sentence_transformers.util' (/usr/local/lib/python3.10/dist-packages/sentence_transformers/util.py)

from huggingface_hub import snapshot_download


def batch_to_device(batch, target_device: str):
Expand Down Expand Up @@ -515,10 +517,21 @@ def smart_batching_collate(self, batch):

return batched_input_features, labels

def _load_sbert_model(self, model_path):
def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
"""
Loads a full sentence-transformers model
"""
# Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544
download_kwargs = {
"repo_id": model_path,
"revision": revision,
"library_name": "sentence-transformers",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
}
model_path = snapshot_download(**download_kwargs)

# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = os.path.join(
model_path, "config_sentence_transformers.json"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ sentence_transformers>=2.2.0
torch
tqdm
rich
tensorboard
tensorboard
huggingface-hub>=0.19.0

0 comments on commit 989b34d

Please sign in to comment.