diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 5e98588538d..da5f1591f1c 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -61,7 +61,16 @@ def __init__( model_name: str = "all-MiniLM-L6-v2", device: str = "cpu", normalize_embeddings: bool = False, + **kwargs: Any ): + """Initialize SentenceTransformerEmbeddingFunction. + + Args: + model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2" + device (str, optional): Device used for computation, defaults to "cpu" + normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False + **kwargs: Additional arguments to pass to the SentenceTransformer model. + """ if model_name not in self.models: try: from sentence_transformers import SentenceTransformer @@ -69,7 +78,7 @@ def __init__( raise ValueError( "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`" ) - self.models[model_name] = SentenceTransformer(model_name, device=device) + self.models[model_name] = SentenceTransformer(model_name, device=device, **kwargs) self._model = self.models[model_name] self._normalize_embeddings = normalize_embeddings