From caa3e126aa0e54662b692cb296d27854545c15d9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 10 Jul 2024 08:23:12 -0700 Subject: [PATCH] Add Ollama Embedding Driver --- CHANGELOG.md | 1 + .../drivers/embedding-drivers.md | 20 +++++++++++++ griptape/drivers/__init__.py | 2 ++ .../embedding/ollama_embedding_driver.py | 28 +++++++++++++++++++ poetry.lock | 14 ++++++++-- pyproject.toml | 1 + .../embedding/test_ollama_embedding_driver.py | 18 ++++++++++++ 7 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 griptape/drivers/embedding/ollama_embedding_driver.py create mode 100644 tests/unit/drivers/embedding/test_ollama_embedding_driver.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a84861a3..aa3b06a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- `OllamaEmbeddingDriver` for generating embeddings with Ollama. ### Changed diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 3f0135ac3..a607b0950 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -97,8 +97,28 @@ embeddings = driver.embed_string("Hello world!") # display the first 3 embeddings print(embeddings[:3]) +``` + +### Ollama + +!!! info + This driver requires the `drivers-embedding-ollama` [extra](../index.md#extras). + +The [OllamaEmbeddingDriver](../../reference/griptape/drivers/embedding/ollama_embedding_driver.md) uses the [Ollama Embeddings API](https://ollama.com/blog/embedding-models). + +```python title="PYTEST_IGNORE" +from griptape.drivers import OllamaEmbeddingDriver + +driver = OllamaEmbeddingDriver( + model="all-minilm", +) +results = driver.embed_string("Hello world!") + +# display the first 3 embeddings +print(results[:3]) ``` + ### Amazon SageMaker Jumpstart The [AmazonSageMakerJumpstartEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS. diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index fa2934a38..4e0fe6672 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -27,6 +27,7 @@ from .embedding.google_embedding_driver import GoogleEmbeddingDriver from .embedding.dummy_embedding_driver import DummyEmbeddingDriver from .embedding.cohere_embedding_driver import CohereEmbeddingDriver +from .embedding.ollama_embedding_driver import OllamaEmbeddingDriver from .vector.base_vector_store_driver import BaseVectorStoreDriver from .vector.local_vector_store_driver import LocalVectorStoreDriver @@ -135,6 +136,7 @@ "GoogleEmbeddingDriver", "DummyEmbeddingDriver", "CohereEmbeddingDriver", + "OllamaEmbeddingDriver", "BaseVectorStoreDriver", "LocalVectorStoreDriver", "PineconeVectorStoreDriver", diff --git a/griptape/drivers/embedding/ollama_embedding_driver.py b/griptape/drivers/embedding/ollama_embedding_driver.py new file mode 100644 index 000000000..d081321c1 --- /dev/null +++ b/griptape/drivers/embedding/ollama_embedding_driver.py @@ -0,0 +1,28 @@ +from __future__ import annotations +from typing import Optional, TYPE_CHECKING +from attrs import define, field, Factory +from griptape.utils import import_optional_dependency +from griptape.drivers import BaseEmbeddingDriver + +if TYPE_CHECKING: + from ollama import Client + + +@define +class OllamaEmbeddingDriver(BaseEmbeddingDriver): + """ + Attributes: + model: Ollama embedding model name. + host: Optional Ollama host. + client: Ollama `Client`. + """ + + model: str = field(kw_only=True, metadata={"serializable": True}) + host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + client: Client = field( + default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), + kw_only=True, + ) + + def try_embed_chunk(self, chunk: str) -> list[float]: + return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"]) diff --git a/poetry.lock b/poetry.lock index da5b152d2..1c4be2d3a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3466,7 +3466,6 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, ] @@ -4710,6 +4709,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4717,8 +4717,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4735,6 +4742,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4742,6 +4750,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -6517,6 +6526,7 @@ drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] drivers-embedding-google = ["google-generativeai"] drivers-embedding-huggingface = ["huggingface-hub", "transformers"] +drivers-embedding-ollama = ["ollama"] drivers-embedding-voyageai = ["voyageai"] drivers-event-listener-amazon-iot = ["boto3"] drivers-event-listener-amazon-sqs = ["boto3"] @@ -6555,4 +6565,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "a38d0f0671ff5deb42783721f4ce103e670a689f618e950451a807921f5d7139" +content-hash = "b0688b6e39e07dce28c6609067b46e5047c56c7428ca11f626ee862e3c8daf20" diff --git a/pyproject.toml b/pyproject.toml index ffcba09af..c74e703b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ drivers-embedding-huggingface = ["huggingface-hub", "transformers"] drivers-embedding-voyageai = ["voyageai"] drivers-embedding-google = ["google-generativeai"] drivers-embedding-cohere = ["cohere"] +drivers-embedding-ollama = ["ollama"] drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-scraper-markdownify = ["playwright", "beautifulsoup4", "markdownify"] diff --git a/tests/unit/drivers/embedding/test_ollama_embedding_driver.py b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py new file mode 100644 index 000000000..3886ab874 --- /dev/null +++ b/tests/unit/drivers/embedding/test_ollama_embedding_driver.py @@ -0,0 +1,18 @@ +import pytest +from griptape.drivers import OllamaEmbeddingDriver + + +class TestOllamaEmbeddingDriver: + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + mock_client = mocker.patch("ollama.Client") + + mock_client.return_value.embeddings.return_value = {"embedding": [0, 1, 0]} + + return mock_client + + def test_init(self): + assert OllamaEmbeddingDriver(model="foo") + + def test_try_embed_chunk(self): + assert OllamaEmbeddingDriver(model="foo").try_embed_chunk("foobar") == [0, 1, 0]