diff --git a/CHANGELOG.md b/CHANGELOG.md index e247a2d54..e5f638411 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added - Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`. +- `OllamaEmbeddingDriver` for generating embeddings with Ollama. ### Changed diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 3f0135ac3..a607b0950 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -97,8 +97,28 @@ embeddings = driver.embed_string("Hello world!") # display the first 3 embeddings print(embeddings[:3]) +``` + +### Ollama + +!!! info + This driver requires the `drivers-embedding-ollama` [extra](../index.md#extras). + +The [OllamaEmbeddingDriver](../../reference/griptape/drivers/embedding/ollama_embedding_driver.md) uses the [Ollama Embeddings API](https://ollama.com/blog/embedding-models). + +```python title="PYTEST_IGNORE" +from griptape.drivers import OllamaEmbeddingDriver + +driver = OllamaEmbeddingDriver( + model="all-minilm", +) +results = driver.embed_string("Hello world!") + +# display the first 3 embeddings +print(results[:3]) ``` + ### Amazon SageMaker Jumpstart The [AmazonSageMakerJumpstartEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS. diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index fa2934a38..4e0fe6672 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -27,6 +27,7 @@ from .embedding.google_embedding_driver import GoogleEmbeddingDriver from .embedding.dummy_embedding_driver import DummyEmbeddingDriver from .embedding.cohere_embedding_driver import CohereEmbeddingDriver +from .embedding.ollama_embedding_driver import OllamaEmbeddingDriver from .vector.base_vector_store_driver import BaseVectorStoreDriver from .vector.local_vector_store_driver import LocalVectorStoreDriver @@ -135,6 +136,7 @@ "GoogleEmbeddingDriver", "DummyEmbeddingDriver", "CohereEmbeddingDriver", + "OllamaEmbeddingDriver", "BaseVectorStoreDriver", "LocalVectorStoreDriver", "PineconeVectorStoreDriver", diff --git a/griptape/drivers/embedding/ollama_embedding_driver.py b/griptape/drivers/embedding/ollama_embedding_driver.py new file mode 100644 index 000000000..d081321c1 --- /dev/null +++ b/griptape/drivers/embedding/ollama_embedding_driver.py @@ -0,0 +1,28 @@ +from __future__ import annotations +from typing import Optional, TYPE_CHECKING +from attrs import define, field, Factory +from griptape.utils import import_optional_dependency +from griptape.drivers import BaseEmbeddingDriver + +if TYPE_CHECKING: + from ollama import Client + + +@define +class OllamaEmbeddingDriver(BaseEmbeddingDriver): + """ + Attributes: + model: Ollama embedding model name. + host: Optional Ollama host. + client: Ollama `Client`. + """ + + model: str = field(kw_only=True, metadata={"serializable": True}) + host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + client: Client = field( + default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), + kw_only=True, + ) + + def try_embed_chunk(self, chunk: str) -> list[float]: + return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"]) diff --git a/poetry.lock b/poetry.lock index 4d9ad0f24..ea89ac4b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6526,6 +6526,7 @@ drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] drivers-embedding-google = ["google-generativeai"] drivers-embedding-huggingface = ["huggingface-hub", "transformers"] +drivers-embedding-ollama = ["ollama"] drivers-embedding-voyageai = ["voyageai"] drivers-event-listener-amazon-iot = ["boto3"] drivers-event-listener-amazon-sqs = ["boto3"] @@ -6564,4 +6565,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "38b27b1953b5b0e75a285951390e1cbc6724d1d6f015059fde2e15733a308097" +content-hash = "cdcfb6d4a27cbce4b7538548c1e13a361bf2e14cb0c722ba5fdf6e84e7d37441" diff --git a/pyproject.toml b/pyproject.toml index b0b28a8f6..64a7ad585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ drivers-embedding-huggingface = ["huggingface-hub", "transformers"] drivers-embedding-voyageai = ["voyageai"] drivers-embedding-google = ["google-generativeai"] drivers-embedding-cohere = ["cohere"] +drivers-embedding-ollama = ["ollama"] drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-scraper-markdownify = ["playwright", "beautifulsoup4", "markdownify"] diff --git a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py new file mode 100644 index 000000000..3886ab874 --- /dev/null +++ b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py @@ -0,0 +1,18 @@ +import pytest +from griptape.drivers import OllamaEmbeddingDriver + + +class TestOllamaEmbeddingDriver: + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + mock_client = mocker.patch("ollama.Client") + + mock_client.return_value.embeddings.return_value = {"embedding": [0, 1, 0]} + + return mock_client + + def test_init(self): + assert OllamaEmbeddingDriver(model="foo") + + def test_try_embed_chunk(self): + assert OllamaEmbeddingDriver(model="foo").try_embed_chunk("foobar") == [0, 1, 0]