diff --git a/chromadb/test/ef/test_ollama_ef.py b/chromadb/test/ef/test_ollama_ef.py
new file mode 100644
index 00000000000..d44f1e8e6d1
--- /dev/null
+++ b/chromadb/test/ef/test_ollama_ef.py
@@ -0,0 +1,34 @@
+import os
+
+import pytest
+import requests
+from requests import HTTPError
+from requests.exceptions import ConnectionError
+
+from chromadb.utils.embedding_functions import OllamaEmbeddingFunction
+
+
+def test_ollama() -> None:
+    """
+    To set up the Ollama server, follow the instructions at: https://github.com/ollama/ollama?tab=readme-ov-file
+    Export the OLLAMA_SERVER_URL and OLLAMA_MODEL environment variables.
+    """
+    if (
+        os.environ.get("OLLAMA_SERVER_URL") is None
+        or os.environ.get("OLLAMA_MODEL") is None
+    ):
+        pytest.skip(
+            "OLLAMA_SERVER_URL or OLLAMA_MODEL environment variable not set. Skipping test."
+        )
+    try:
+        response = requests.get(os.environ.get("OLLAMA_SERVER_URL", ""))
+        # If the response was successful, no exception will be raised
+        response.raise_for_status()
+    except (HTTPError, ConnectionError):
+        pytest.skip("Ollama server not running. Skipping test.")
+    ef = OllamaEmbeddingFunction(
+        model_name=os.environ.get("OLLAMA_MODEL") or "nomic-embed-text",
+        url=f"{os.environ.get('OLLAMA_SERVER_URL')}/embeddings",
+    )
+    embeddings = ef(["Here is an article about llamas...", "this is another article"])
+    assert len(embeddings) == 2
diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py
index da5f1591f1c..22d57e6a3d6 100644
--- a/chromadb/utils/embedding_functions.py
+++ b/chromadb/utils/embedding_functions.py
@@ -61,7 +61,7 @@ def __init__(
         model_name: str = "all-MiniLM-L6-v2",
         device: str = "cpu",
         normalize_embeddings: bool = False,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         """Initialize SentenceTransformerEmbeddingFunction.
 
@@ -78,7 +78,9 @@ def __init__(
             raise ValueError(
                 "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`"
             )
-        self.models[model_name] = SentenceTransformer(model_name, device=device, **kwargs)
+        self.models[model_name] = SentenceTransformer(
+            model_name, device=device, **kwargs
+        )
         self._model = self.models[model_name]
         self._normalize_embeddings = normalize_embeddings
 
@@ -828,6 +830,62 @@ def __call__(self, input: Documents) -> Embeddings:
         )
 
 
+class OllamaEmbeddingFunction(EmbeddingFunction[Documents]):
+    """
+    Generates embeddings for a list of texts using the Ollama embeddings API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings).
+    """
+
+    def __init__(self, url: str, model_name: str) -> None:
+        """
+        Initialize the Ollama Embedding Function.
+
+        Args:
+            url (str): The URL of the Ollama server.
+            model_name (str): The name of the model to use for text embeddings, e.g. "nomic-embed-text" (see https://ollama.com/library for available models).
+        """
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. Please install it with `pip install requests`"
+            )
+        self._api_url = url
+        self._model_name = model_name
+        self._session = requests.Session()
+
+    def __call__(self, input: Documents) -> Embeddings:
+        """
+        Get the embeddings for a list of texts.
+
+        Args:
+            input (Documents): A list of texts to get embeddings for.
+
+        Returns:
+            Embeddings: The embeddings for the texts.
+
+        Example:
+            >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text")
+            >>> texts = ["Hello, world!", "How are you?"]
+            >>> embeddings = ollama_ef(texts)
+        """
+        # Call the Ollama server once per document
+        texts = input if isinstance(input, list) else [input]
+        embeddings = [
+            self._session.post(
+                self._api_url, json={"model": self._model_name, "prompt": text}
+            ).json()
+            for text in texts
+        ]
+        return cast(
+            Embeddings,
+            [
+                embedding["embedding"]
+                for embedding in embeddings
+                if "embedding" in embedding
+            ],
+        )
+
+
 # List of all classes in this module
 _classes = [
     name
diff --git a/clients/js/src/embeddings/OllamaEmbeddingFunction.ts b/clients/js/src/embeddings/OllamaEmbeddingFunction.ts
new file mode 100644
index 00000000000..bef8806f158
--- /dev/null
+++ b/clients/js/src/embeddings/OllamaEmbeddingFunction.ts
@@ -0,0 +1,34 @@
+import { IEmbeddingFunction } from "./IEmbeddingFunction";
+
+export class OllamaEmbeddingFunction implements IEmbeddingFunction {
+  private readonly url: string;
+  private readonly model: string;
+
+  constructor({ url, model }: { url: string; model: string }) {
+    // The Ollama server is called via fetch in generate(), so there is
+    // no client object to construct here.
+    this.url = url;
+    this.model = model;
+  }
+
+  public async generate(texts: string[]) {
+    let embeddings: number[][] = [];
+    for (let text of texts) {
+      const response = await fetch(this.url, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+        },
+        body: JSON.stringify({ model: this.model, prompt: text }),
+      });
+
+      if (!response.ok) {
+        throw new Error(`Failed to generate embeddings: ${response.status} (${response.statusText})`);
+      }
+      let finalResponse = await response.json();
+      embeddings.push(finalResponse["embedding"]);
+    }
+    return embeddings;
+  }
+
+}
diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts
index 3eb0d0832b6..c925f9e4871 100644
--- a/clients/js/src/index.ts
+++ b/clients/js/src/index.ts
@@ -2,7 +2,6 @@ export { ChromaClient } from "./ChromaClient";
 export { AdminClient } from "./AdminClient";
 export { CloudClient } from "./CloudClient";
 export { Collection } from "./Collection";
-
 export { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction";
 export { OpenAIEmbeddingFunction } from "./embeddings/OpenAIEmbeddingFunction";
 export { CohereEmbeddingFunction } from "./embeddings/CohereEmbeddingFunction";
@@ -11,6 +10,8 @@ export { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction"
 export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbeddingServerFunction";
 export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction";
 export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction";
+export { OllamaEmbeddingFunction } from "./embeddings/OllamaEmbeddingFunction";
+
 export {
   IncludeEnum,
diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts
index b569f36bedc..41b3de3fef5 100644
--- a/clients/js/test/add.collections.test.ts
+++ b/clients/js/test/add.collections.test.ts
@@ -5,6 +5,7 @@ import { METADATAS } from "./data";
 import { IncludeEnum } from "../src/types";
 import { OpenAIEmbeddingFunction } from "../src/embeddings/OpenAIEmbeddingFunction";
 import { CohereEmbeddingFunction } from "../src/embeddings/CohereEmbeddingFunction";
+import { OllamaEmbeddingFunction } from "../src/embeddings/OllamaEmbeddingFunction";
"../src/embeddings/OllamaEmbeddingFunction"; test("it should add single embeddings to a collection", async () => { await chroma.reset(); const collection = await chroma.createCollection({ name: "test" }); @@ -120,3 +121,30 @@ test("should error on empty embedding", async () => { expect(e.message).toMatch("got empty embedding at pos"); } }); + +if (!process.env.OLLAMA_SERVER_URL) { + test.skip("it should use ollama EF, OLLAMA_SERVER_URL not defined", async () => {}); +} else { + test("it should use ollama EF", async () => { + await chroma.reset(); + const embedder = new OllamaEmbeddingFunction({ + url: + process.env.OLLAMA_SERVER_URL || + "http://127.0.0.1:11434/api/embeddings", + model: "nomic-embed-text", + }); + const collection = await chroma.createCollection({ + name: "test", + embeddingFunction: embedder, + }); + const embeddings = await embedder.generate(DOCUMENTS); + await collection.add({ ids: IDS, embeddings: embeddings }); + const count = await collection.count(); + expect(count).toBe(3); + var res = await collection.get({ + ids: IDS, + include: [IncludeEnum.Embeddings], + }); + expect(res.embeddings).toEqual(embeddings); // reverse because of the order of the ids + }); +} diff --git a/examples/use_with/ollama.md b/examples/use_with/ollama.md new file mode 100644 index 00000000000..7b6977fec7d --- /dev/null +++ b/examples/use_with/ollama.md @@ -0,0 +1,40 @@ +# Ollama + +First let's run a local docker container with Ollama. We'll pull `nomic-embed-text` model: + +```bash +docker run -d -v ./ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama +docker exec -it ollama ollama run nomic-embed-text # press Ctrl+D to exit after model downloads successfully +# test it +curl http://localhost:11434/api/embeddings -d '{"model": "nomic-embed-text","prompt": "Here is an article about llamas..."}' +``` + +Now let's configure our OllamaEmbeddingFunction Embedding (python) function with the default Ollama endpoint: + +```python +import chromadb +from chromadb.utils.embedding_functions import OllamaEmbeddingFunction + +client = chromadb.PersistentClient(path="ollama") + +# create EF with custom endpoint +ef = OllamaEmbeddingFunction( + model_name="nomic-embed-text", + url="http://127.0.0.1:11434/api/embeddings", +) + +print(ef(["Here is an article about llamas..."])) +``` + +For JS users, you can use the `OllamaEmbeddingFunction` class to create embeddings: + +```javascript +const {OllamaEmbeddingFunction} = require('chromadb'); +const embedder = new OllamaEmbeddingFunction({ + url: "http://127.0.0.1:11434/api/embeddings", + model: "llama2" +}) + +// use directly +const embeddings = embedder.generate(["Here is an article about llamas..."]) +```