Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ollama Embedding Driver #953

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.

### Changed

Expand Down
20 changes: 20 additions & 0 deletions docs/griptape-framework/drivers/embedding-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,28 @@ embeddings = driver.embed_string("Hello world!")

# display the first 3 embeddings
print(embeddings[:3])
```

### Ollama

!!! info
    This driver requires the `drivers-embedding-ollama` [extra](../index.md#extras).

The [OllamaEmbeddingDriver](../../reference/griptape/drivers/embedding/ollama_embedding_driver.md) uses the [Ollama Embeddings API](https://ollama.com/blog/embedding-models).

```python title="PYTEST_IGNORE"
from griptape.drivers import OllamaEmbeddingDriver

driver = OllamaEmbeddingDriver(
model="all-minilm",
)

results = driver.embed_string("Hello world!")

# display the first 3 values of the embedding vector
print(results[:3])
```

### Amazon SageMaker Jumpstart

The [AmazonSageMakerJumpstartEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS.
Expand Down
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .embedding.google_embedding_driver import GoogleEmbeddingDriver
from .embedding.dummy_embedding_driver import DummyEmbeddingDriver
from .embedding.cohere_embedding_driver import CohereEmbeddingDriver
from .embedding.ollama_embedding_driver import OllamaEmbeddingDriver

from .vector.base_vector_store_driver import BaseVectorStoreDriver
from .vector.local_vector_store_driver import LocalVectorStoreDriver
Expand Down Expand Up @@ -135,6 +136,7 @@
"GoogleEmbeddingDriver",
"DummyEmbeddingDriver",
"CohereEmbeddingDriver",
"OllamaEmbeddingDriver",
"BaseVectorStoreDriver",
"LocalVectorStoreDriver",
"PineconeVectorStoreDriver",
Expand Down
28 changes: 28 additions & 0 deletions griptape/drivers/embedding/ollama_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.utils import import_optional_dependency
from griptape.drivers import BaseEmbeddingDriver

if TYPE_CHECKING:
from ollama import Client


@define
class OllamaEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that generates embeddings through an Ollama `Client`.

    Attributes:
        model: Ollama embedding model name.
        host: Optional Ollama host.
        client: Ollama `Client`. When not supplied, one is created from the
            optional `ollama` dependency and given `host`.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    client: Client = field(
        # The factory receives the partially-built instance so it can read `host`.
        default=Factory(
            lambda self: import_optional_dependency("ollama").Client(host=self.host),
            takes_self=True,
        ),
        kw_only=True,
    )

    def try_embed_chunk(self, chunk: str) -> list[float]:
        """Request an embedding for a single chunk of text.

        Args:
            chunk: Text sent to the client as the `prompt`.

        Returns:
            The value stored under the response's "embedding" key, copied
            into a new list.
        """
        response = self.client.embeddings(model=self.model, prompt=chunk)
        embedding = response["embedding"]
        return list(embedding)
14 changes: 12 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ drivers-embedding-huggingface = ["huggingface-hub", "transformers"]
drivers-embedding-voyageai = ["voyageai"]
drivers-embedding-google = ["google-generativeai"]
drivers-embedding-cohere = ["cohere"]
drivers-embedding-ollama = ["ollama"]

drivers-web-scraper-trafilatura = ["trafilatura"]
drivers-web-scraper-markdownify = ["playwright", "beautifulsoup4", "markdownify"]
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/drivers/embedding/test_ollama_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from griptape.drivers import OllamaEmbeddingDriver


class TestOllamaEmbeddingDriver:
    """Unit tests for `OllamaEmbeddingDriver` run against a patched `ollama.Client`."""

    @pytest.fixture(autouse=True)
    def mock_client(self, mocker):
        # Patch the client class so constructing the driver uses the mock,
        # and make its `embeddings` call return a fixed payload.
        patched_cls = mocker.patch("ollama.Client")
        patched_cls.return_value.embeddings.return_value = {"embedding": [0, 1, 0]}
        return patched_cls

    def test_init(self):
        driver = OllamaEmbeddingDriver(model="foo")

        assert driver

    def test_try_embed_chunk(self):
        driver = OllamaEmbeddingDriver(model="foo")

        embedding = driver.try_embed_chunk("foobar")

        assert embedding == [0, 1, 0]
Loading