Skip to content

Commit

Permalink
Merge branch 'dev' into feature/native-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jul 10, 2024
2 parents 29481f7 + c20617a commit 97d620f
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- Native function calling support to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AnthropicPromptDriver`, `AmazonBedrockPromptDriver`, `GooglePromptDriver`, and `CoherePromptDriver`.
- `OllamaEmbeddingDriver` for generating embeddings with Ollama.

### Changed

Expand Down
20 changes: 20 additions & 0 deletions docs/griptape-framework/drivers/embedding-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,28 @@ embeddings = driver.embed_string("Hello world!")

# display the first 3 embeddings
print(embeddings[:3])
```

### Ollama

!!! info
This driver requires the `drivers-embedding-ollama` [extra](../index.md#extras).

The [OllamaEmbeddingDriver](../../reference/griptape/drivers/embedding/ollama_embedding_driver.md) uses the [Ollama Embeddings API](https://ollama.com/blog/embedding-models).

```python title="PYTEST_IGNORE"
from griptape.drivers import OllamaEmbeddingDriver

driver = OllamaEmbeddingDriver(
model="all-minilm",
)

results = driver.embed_string("Hello world!")

# display the first 3 embeddings
print(results[:3])
```

### Amazon SageMaker Jumpstart

The [AmazonSageMakerJumpstartEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS.
Expand Down
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .embedding.google_embedding_driver import GoogleEmbeddingDriver
from .embedding.dummy_embedding_driver import DummyEmbeddingDriver
from .embedding.cohere_embedding_driver import CohereEmbeddingDriver
from .embedding.ollama_embedding_driver import OllamaEmbeddingDriver

from .vector.base_vector_store_driver import BaseVectorStoreDriver
from .vector.local_vector_store_driver import LocalVectorStoreDriver
Expand Down Expand Up @@ -135,6 +136,7 @@
"GoogleEmbeddingDriver",
"DummyEmbeddingDriver",
"CohereEmbeddingDriver",
"OllamaEmbeddingDriver",
"BaseVectorStoreDriver",
"LocalVectorStoreDriver",
"PineconeVectorStoreDriver",
Expand Down
28 changes: 28 additions & 0 deletions griptape/drivers/embedding/ollama_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from attrs import define, field, Factory
from griptape.utils import import_optional_dependency
from griptape.drivers import BaseEmbeddingDriver

if TYPE_CHECKING:
from ollama import Client


@define
class OllamaEmbeddingDriver(BaseEmbeddingDriver):
"""
Attributes:
model: Ollama embedding model name.
host: Optional Ollama host.
client: Ollama `Client`.
"""

model: str = field(kw_only=True, metadata={"serializable": True})
host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True),
kw_only=True,
)

def try_embed_chunk(self, chunk: str) -> list[float]:
return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
3 changes: 2 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ drivers-embedding-huggingface = ["huggingface-hub", "transformers"]
drivers-embedding-voyageai = ["voyageai"]
drivers-embedding-google = ["google-generativeai"]
drivers-embedding-cohere = ["cohere"]
drivers-embedding-ollama = ["ollama"]

drivers-web-scraper-trafilatura = ["trafilatura"]
drivers-web-scraper-markdownify = ["playwright", "beautifulsoup4", "markdownify"]
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/drivers/embedding/test_ollama_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from griptape.drivers import OllamaEmbeddingDriver


class TestOllamaEmbeddingDriver:
@pytest.fixture(autouse=True)
def mock_client(self, mocker):
mock_client = mocker.patch("ollama.Client")

mock_client.return_value.embeddings.return_value = {"embedding": [0, 1, 0]}

return mock_client

def test_init(self):
assert OllamaEmbeddingDriver(model="foo")

def test_try_embed_chunk(self):
assert OllamaEmbeddingDriver(model="foo").try_embed_chunk("foobar") == [0, 1, 0]

0 comments on commit 97d620f

Please sign in to comment.