diff --git a/CHANGELOG.md b/CHANGELOG.md index 4152922ae3..a9a43be273 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,20 +32,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Renamed parameter `s3_client` on `AmazonS3FileManagerDriver` to `client`. - **BREAKING**: Renamed parameter `s3_client` on `AwsS3Tool` to `client`. - **BREAKING**: Renamed parameter `pusher_client` on `PusherEventListenerDriver` to `client`. -- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`. -- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `client`. -- **BREAKING**: Renamed parameter `collection` on `AstraDbVectorStoreDriver` to `client`. - **BREAKING**: Renamed parameter `mq` on `MarqoVectorStoreDriver` to `client`. -- **BREAKING**: Renamed parameter `engine` on `PgVectorVectorStoreDriver` to `client`. -- **BREAKING**: Renamed parameter `index` on `PineconeVectorStoreDriver` to `client`. +- **BREAKING**: Renamed parameter `model_client` on `GooglePromptDriver` to `client`. - **BREAKING**: Renamed parameter `model_client` on `GoogleTokenizer` to `client`. +- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `text_generation_pipeline`. - Updated `JsonArtifact` value converter to properly handle more types. - `AudioArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - `ImageArtifact` now subclasses `BlobArtifact` instead of `MediaArtifact`. - Removed `__add__` method from `BaseArtifact`, implemented it where necessary. - Generic type support to `ListArtifact`. - Iteration support to `ListArtifact`. -- The `client` parameter on `Driver`s that use a client are now lazily initialized. +- Several places where API clients are initialized are now lazy loaded. ## [0.31.0] - 2024-09-03 diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index bfe1586eff..273db870b2 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -36,12 +36,12 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ), kw_only=True, ) - _client: TextGenerationPipeline = field( - default=None, kw_only=True, alias="client", metadata={"serializable": False} + _text_generation_pipeline: TextGenerationPipeline = field( + default=None, kw_only=True, alias="text_generation_pipeline", metadata={"serializable": False} ) @lazy_property() - def client(self) -> TextGenerationPipeline: + def text_generation_pipeline(self) -> TextGenerationPipeline: return import_optional_dependency("transformers").pipeline( "text-generation", model=self.model, @@ -53,7 +53,7 @@ def client(self) -> TextGenerationPipeline: def try_run(self, prompt_stack: PromptStack) -> Message: messages = self._prompt_stack_to_messages(prompt_stack) - result = self.client( + result = self.text_generation_pipeline( messages, max_new_tokens=self.max_tokens, temperature=self.temperature, diff --git a/griptape/drivers/vector/astradb_vector_store_driver.py b/griptape/drivers/vector/astradb_vector_store_driver.py index 92366f4bb7..85832be00e 100644 --- a/griptape/drivers/vector/astradb_vector_store_driver.py +++ b/griptape/drivers/vector/astradb_vector_store_driver.py @@ -9,8 +9,8 @@ from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from astrapy import Collection - from astrapy.authentication import TokenProvider + import astrapy + import astrapy.authentication @define @@ -27,33 +27,35 @@ class AstraDbVectorStoreDriver(BaseVectorStoreDriver): It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values. astra_db_namespace: optional specification of the namespace (in the Astra database) for the data. *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store. + caller_name: the name of the caller for the Astra DB client. Defaults to "griptape". + client: an instance of `astrapy.DataAPIClient` for the Astra DB. + collection: an instance of `astrapy.Collection` for the Astra DB. """ api_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - token: Optional[str | TokenProvider] = field(kw_only=True, default=None, metadata={"serializable": False}) + token: Optional[str | astrapy.authentication.TokenProvider] = field( + kw_only=True, default=None, metadata={"serializable": False} + ) collection_name: str = field(kw_only=True, metadata={"serializable": True}) environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - _client: Collection = field(default=None, kw_only=True, metadata={"serializable": False}) + caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False}) + _client: astrapy.DataAPIClient = field(default=None, kw_only=True, metadata={"serializable": False}) + _collection: astrapy.Collection = field(default=None, kw_only=True, metadata={"serializable": False}) @lazy_property() - def client(self) -> Collection: - astrapy = import_optional_dependency("astrapy") - return ( - astrapy.DataAPIClient( - caller_name="griptape", - environment=self.environment, - ) - .get_database( - self.api_endpoint, - token=self.token, - namespace=self.astra_db_namespace, - ) - .get_collection( - name=self.collection_name, - ) + def client(self) -> astrapy.DataAPIClient: + return import_optional_dependency("astrapy").DataAPIClient( + caller_name=self.caller_name, + environment=self.environment, ) + @lazy_property() + def collection(self) -> astrapy.Collection: + return self.client.get_database( + self.api_endpoint, token=self.token, namespace=self.astra_db_namespace + ).get_collection(self.collection_name) + def delete_vector(self, vector_id: str) -> None: """Delete a vector from Astra DB store. @@ -63,7 +65,7 @@ def delete_vector(self, vector_id: str) -> None: Args: vector_id: ID of the vector to delete. """ - self.client.delete_one({"_id": vector_id}) + self.collection.delete_one({"_id": vector_id}) def upsert_vector( self, @@ -94,10 +96,10 @@ def upsert_vector( if v is not None } if vector_id is not None: - self.client.find_one_and_replace({"_id": vector_id}, document, upsert=True) + self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True) return vector_id else: - insert_result = self.client.insert_one(document) + insert_result = self.collection.insert_one(document) return insert_result.inserted_id def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: @@ -111,7 +113,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None. """ find_filter = {k: v for k, v in {"_id": vector_id, "namespace": namespace}.items() if v is not None} - match = self.client.find_one(filter=find_filter, projection={"*": 1}) + match = self.collection.find_one(filter=find_filter, projection={"*": 1}) if match is not None: return BaseVectorStoreDriver.Entry( id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace") @@ -133,7 +135,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto BaseVectorStoreDriver.Entry( id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace") ) - for match in self.client.find(filter=find_filter, projection={"*": 1}) + for match in self.collection.find(filter=find_filter, projection={"*": 1}) ] def query( @@ -166,7 +168,7 @@ def query( find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None vector = self.embedding_driver.embed_string(query) ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT - matches = self.client.find( + matches = self.collection.find( filter=find_filter, sort={"$vector": vector}, limit=ann_limit, diff --git a/griptape/drivers/vector/pgvector_vector_store_driver.py b/griptape/drivers/vector/pgvector_vector_store_driver.py index 553cca486f..1b2aa471db 100644 --- a/griptape/drivers/vector/pgvector_vector_store_driver.py +++ b/griptape/drivers/vector/pgvector_vector_store_driver.py @@ -12,7 +12,7 @@ from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from sqlalchemy.engine import Engine + import sqlalchemy @define @@ -30,12 +30,12 @@ class PgVectorVectorStoreDriver(BaseVectorStoreDriver): create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) table_name: str = field(kw_only=True, metadata={"serializable": True}) _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True)) - _client: Engine = field(default=None, kw_only=True, metadata={"serializable": False}) + _sqlalchemy_engine: sqlalchemy.Engine = field(default=None, kw_only=True, metadata={"serializable": False}) @connection_string.validator # pyright: ignore[reportAttributeAccessIssue] def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None: # If an engine is provided, the connection string is not used. - if self._client is not None: + if self._sqlalchemy_engine is not None: return # If an engine is not provided, a connection string is required. @@ -46,7 +46,7 @@ def validate_connection_string(self, _: Attribute, connection_string: Optional[s raise ValueError("The connection string must describe a Postgres database connection") @lazy_property() - def client(self) -> Engine: + def sqlalchemy_engine(self) -> sqlalchemy.Engine: return import_optional_dependency("sqlalchemy").create_engine( self.connection_string, **self.create_engine_params ) @@ -62,15 +62,15 @@ def setup( sqlalchemy_sql = import_optional_dependency("sqlalchemy.sql") if install_uuid_extension: - with self.client.begin() as conn: + with self.sqlalchemy_engine.begin() as conn: conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')) if install_vector_extension: - with self.client.begin() as conn: + with self.sqlalchemy_engine.begin() as conn: conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "vector";')) if create_schema: - self._model.metadata.create_all(self.client) + self._model.metadata.create_all(self.sqlalchemy_engine) def upsert_vector( self, @@ -84,7 +84,7 @@ def upsert_vector( """Inserts or updates a vector in the collection.""" sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm") - with sqlalchemy_orm.Session(self.client) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs) obj = session.merge(obj) @@ -96,7 +96,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Base """Retrieves a specific vector entry from the collection based on its identifier and optional namespace.""" sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm") - with sqlalchemy_orm.Session(self.client) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: result = session.get(self._model, vector_id) return BaseVectorStoreDriver.Entry( @@ -110,7 +110,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace.""" sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm") - with sqlalchemy_orm.Session(self.client) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: query = session.query(self._model) if namespace: query = query.filter_by(namespace=namespace) @@ -151,7 +151,7 @@ def query( op = distance_metrics[distance_metric] - with sqlalchemy_orm.Session(self.client) as session: + with sqlalchemy_orm.Session(self.sqlalchemy_engine) as session: vector = self.embedding_driver.embed_string(query) # The query should return both the vector and the distance metric score. diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index 0ea997a213..500b090f58 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -18,20 +18,21 @@ class PineconeVectorStoreDriver(BaseVectorStoreDriver): index_name: str = field(kw_only=True, metadata={"serializable": True}) environment: str = field(kw_only=True, metadata={"serializable": True}) project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - _client: pinecone.Index = field(default=None, kw_only=True, metadata={"serializable": False}) + _client: pinecone.Pinecone = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + _index: pinecone.Index = field(default=None, kw_only=True, alias="index", metadata={"serializable": False}) @lazy_property() - def client(self) -> pinecone.Index: - return ( - import_optional_dependency("pinecone") - .Pinecone( - api_key=self.api_key, - environment=self.environment, - project_name=self.project_name, - ) - .Index(self.index_name) + def client(self) -> pinecone.Pinecone: + return import_optional_dependency("pinecone").Pinecone( + api_key=self.api_key, + environment=self.environment, + project_name=self.project_name, ) + @lazy_property() + def index(self) -> pinecone.Index: + return self.client.get_index(self.index_name) + def upsert_vector( self, vector: list[float], @@ -44,12 +45,12 @@ def upsert_vector( params: dict[str, Any] = {"namespace": namespace} | kwargs - self.client.upsert(vectors=[(vector_id, vector, meta)], **params) + self.index.upsert(vectors=[(vector_id, vector, meta)], **params) return vector_id def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: - result = self.client.fetch(ids=[vector_id], namespace=namespace).to_dict() + result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() vectors = list(result["vectors"].values()) if len(vectors) > 0: @@ -69,7 +70,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto # all values from a namespace: # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5 - results = self.client.query( + results = self.index.query( vector=self.embedding_driver.embed_string(""), top_k=10000, include_metadata=True, @@ -105,7 +106,7 @@ def query( "include_metadata": include_metadata, } | kwargs - results = self.client.query(vector=vector, **params) + results = self.index.query(vector=vector, **params) return [ BaseVectorStoreDriver.Entry( diff --git a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py index e0668e4847..b46ce64512 100644 --- a/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_pgvector_vector_store_driver.py @@ -18,7 +18,7 @@ def embedding_driver(self): return MockEmbeddingDriver() @pytest.fixture() - def mock_client(self): + def mock_engine(self): return MagicMock() @pytest.fixture() @@ -30,25 +30,27 @@ def mock_session(self, mocker): return session - def test_initialize_requires_client_or_connection_string(self, embedding_driver): + def test_initialize_requires_engine_or_connection_string(self, embedding_driver): with pytest.raises(ValueError): PgVectorVectorStoreDriver(embedding_driver=embedding_driver, table_name=self.table_name) - def test_initialize_accepts_client(self, embedding_driver): - client: Any = create_engine(self.connection_string) - PgVectorVectorStoreDriver(embedding_driver=embedding_driver, client=client, table_name=self.table_name) + def test_initialize_accepts_engine(self, embedding_driver): + engine: Any = create_engine(self.connection_string) + PgVectorVectorStoreDriver( + embedding_driver=embedding_driver, sqlalchemy_engine=engine, table_name=self.table_name + ) def test_initialize_accepts_connection_string(self, embedding_driver): PgVectorVectorStoreDriver( embedding_driver=embedding_driver, connection_string=self.connection_string, table_name=self.table_name ) - def test_upsert_vector(self, mock_session, mock_client): + def test_upsert_vector(self, mock_session, mock_engine): test_id = str(uuid.uuid4()) mock_session.merge.return_value = Mock(id=test_id) driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), client=mock_client, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) returned_id = driver.upsert_vector([1.0, 2.0, 3.0]) @@ -57,7 +59,7 @@ def test_upsert_vector(self, mock_session, mock_client): mock_session.merge.assert_called_once() mock_session.commit.assert_called_once() - def test_load_entry(self, mock_session, mock_client): + def test_load_entry(self, mock_session, mock_engine): test_id = str(uuid.uuid4()) test_vec = [0.1, 0.2, 0.3] test_namespace = str(uuid.uuid4()) @@ -65,7 +67,7 @@ def test_load_entry(self, mock_session, mock_client): mock_session.get.return_value = Mock(id=test_id, vector=test_vec, namespace=test_namespace, meta=test_meta) driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), client=mock_client, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) entry = driver.load_entry(vector_id=test_id) @@ -75,7 +77,7 @@ def test_load_entry(self, mock_session, mock_client): assert entry.namespace == test_namespace assert entry.meta == test_meta - def test_load_entries(self, mock_session, mock_client): + def test_load_entries(self, mock_session, mock_engine): test_ids = [str(uuid.uuid4()), str(uuid.uuid4())] test_vecs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] test_namespaces = [str(uuid.uuid4()), str(uuid.uuid4())] @@ -88,7 +90,7 @@ def test_load_entries(self, mock_session, mock_client): mock_session.query.return_value = mock_query driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), client=mock_client, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) entries = driver.load_entries() @@ -102,15 +104,15 @@ def test_load_entries(self, mock_session, mock_client): assert entries[0].meta == test_metas[0] assert entries[1].meta == test_metas[1] - def test_query_invalid_distance_metric(self, mock_client): + def test_query_invalid_distance_metric(self, mock_engine): driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), client=mock_client, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) with pytest.raises(ValueError): driver.query("test", distance_metric="invalid") - def test_query(self, mock_session, mock_client): + def test_query(self, mock_session, mock_engine): test_ids = [str(uuid.uuid4()), str(uuid.uuid4())] test_vecs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] test_namespaces = [str(uuid.uuid4()), str(uuid.uuid4())] @@ -122,7 +124,7 @@ def test_query(self, mock_session, mock_client): mock_session.query().order_by().limit().all.return_value = test_result driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), client=mock_client, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) result = driver.query("some query", include_vectors=True) @@ -136,7 +138,7 @@ def test_query(self, mock_session, mock_client): assert result[0].meta == test_metas[0] assert result[1].meta == test_metas[1] - def test_query_filter(self, mock_session, mock_client): + def test_query_filter(self, mock_session, mock_engine): test_ids = [str(uuid.uuid4()), str(uuid.uuid4())] test_vecs = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] test_namespaces = [str(uuid.uuid4()), str(uuid.uuid4())] @@ -147,7 +149,7 @@ def test_query_filter(self, mock_session, mock_client): mock_session.query().order_by().filter_by().limit().all.return_value = test_result driver = PgVectorVectorStoreDriver( - embedding_driver=MockEmbeddingDriver(), client=mock_client, table_name=self.table_name + embedding_driver=MockEmbeddingDriver(), sqlalchemy_engine=mock_engine, table_name=self.table_name ) result = driver.query("some query", include_vectors=True, filter={"namespace": test_namespaces[0]}) diff --git a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py index 0726a0c7ea..a963fb370e 100644 --- a/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py +++ b/tests/unit/drivers/vector/test_pinecone_vector_storage_driver.py @@ -7,7 +7,7 @@ class TestPineconeVectorStorageDriver: @pytest.fixture(autouse=True) - def _mock_pinecone(self, mocker): + def mock_index(self, mocker): # Create a fake response fake_query_response = { "matches": [{"id": "foo", "values": [0, 1, 0], "score": 42, "metadata": {"foo": "bar"}}], @@ -15,14 +15,21 @@ def _mock_pinecone(self, mocker): } mock_client = mocker.patch("pinecone.Pinecone") - mock_client().Index().upsert.return_value = None - mock_client().Index().query.return_value = fake_query_response - mock_client().create_index.return_value = None + mock_index = mock_client().Index() + mock_index.upsert.return_value = None + mock_index.query.return_value = fake_query_response + mock_index.create_index.return_value = None + + return mock_index @pytest.fixture() - def driver(self): + def driver(self, mock_index): return PineconeVectorStoreDriver( - api_key="foobar", index_name="test", environment="test", embedding_driver=MockEmbeddingDriver() + api_key="foobar", + index_name="test", + environment="test", + embedding_driver=MockEmbeddingDriver(), + index=mock_index, ) def test_upsert_text_artifact(self, driver):