Skip to content

Commit

Permalink
[ENH] Add optional kwargs when initialising SentenceTransformerEmbedd…
Browse files Browse the repository at this point in the history
…ingFunction class (#1891)

## Description of changes  
  
*Summarize the changes made by this PR.*  
 - Improvements & Bug fixes  
- Add optional kwargs for `SetenceTransformer` when initialising
`SentenceTransformerEmbeddingFunction` class (Issue
[#1857](#1857))
  
## Test plan  
*How are these changes tested?*  
  
- [x] Tests pass locally with `pytest` for python
- installing chroma as an editable package locally and testing with the
code
	```python
	import chromadb
	from chromadb.utils import embedding_functions
sentence_transformer_ef =
embedding_functions.SentenceTransformerEmbeddingFunction(prompts={"query":
"query: ", "passage": "passage: "})
	print(sentence_transformer_ef.models['all-MiniLM-L6-v2'].prompts)
	```
	returned
	```bash
	{'query': 'query: ', 'passage': 'passage: '}
	```
	 
## Documentation Changes  
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
I have added the documentation for
`SentenceTransformerEmbeddingFunction` initialisation.

Co-authored-by: sumaiyah <o##wJdTOvNkIC!1R@bQO>
  • Loading branch information
sumaiyah and sumaiyah authored Mar 22, 2024
1 parent 468cb4c commit 8f189d5
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,24 @@ def __init__(
model_name: str = "all-MiniLM-L6-v2",
device: str = "cpu",
normalize_embeddings: bool = False,
**kwargs: Any
):
"""Initialize SentenceTransformerEmbeddingFunction.
Args:
model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2"
device (str, optional): Device used for computation, defaults to "cpu"
normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False
**kwargs: Additional arguments to pass to the SentenceTransformer model.
"""
if model_name not in self.models:
try:
from sentence_transformers import SentenceTransformer
except ImportError:
raise ValueError(
"The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
)
self.models[model_name] = SentenceTransformer(model_name, device=device)
self.models[model_name] = SentenceTransformer(model_name, device=device, **kwargs)
self._model = self.models[model_name]
self._normalize_embeddings = normalize_embeddings

Expand Down

0 comments on commit 8f189d5

Please sign in to comment.