From 1faa69ec7fa860e265c8388746779e7d89a1077d Mon Sep 17 00:00:00 2001 From: Hammad Bashir Date: Thu, 9 Nov 2023 17:22:56 -0800 Subject: [PATCH] [BUG] Fix multitenancy (#1372) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Fixes a bug with multitenancy where name= filtering path is always being called - New functionality - ... ## Test plan *How are these changes tested?* I added tests to verify collections can be added to across database and tenants ## Documentation Changes None reauired --- chromadb/db/mixins/sysdb.py | 4 +- chromadb/test/client/test_database_tenant.py | 91 ++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/chromadb/db/mixins/sysdb.py b/chromadb/db/mixins/sysdb.py index d22960d03f4..e609ea80b19 100644 --- a/chromadb/db/mixins/sysdb.py +++ b/chromadb/db/mixins/sysdb.py @@ -414,7 +414,9 @@ def get_collections( if name: q = q.where(collections_t.name == ParameterValue(name)) - if tenant and database: + # Only if we have a name, tenant and database do we need to filter databases + # Given an id, we can uniquely identify the collection so we don't need to filter databases + if id is None and tenant and database: databases_t = Table("databases") q = q.where( collections_t.database_id diff --git a/chromadb/test/client/test_database_tenant.py b/chromadb/test/client/test_database_tenant.py index 00c502bb0ea..eb20265331e 100644 --- a/chromadb/test/client/test_database_tenant.py +++ b/chromadb/test/client/test_database_tenant.py @@ -63,6 +63,97 @@ def test_database_tenant_collections(client: Client) -> None: assert len(collections) == 0 +def test_database_collections_add(client: Client) -> None: + client.reset() + + # Create a new database in the default tenant + admin_client = AdminClient.from_system(client._system) + admin_client.create_database("test_db") + + # Create collections in this new database + client.set_database(database="test_db") + coll_new = client.create_collection("collection_new") + + # Create collections in the default database + client.set_database(database=DEFAULT_DATABASE) + coll_default = client.create_collection("collection_default") + + records_new = { + "ids": ["a", "b", "c"], + "embeddings": [[1.0, 2.0, 3.0] for _ in range(3)], + "documents": ["a", "b", "c"], + } + + records_default = { + "ids": ["c", "d", "e"], + "embeddings": [[4.0, 5.0, 6.0] for _ in range(3)], + "documents": ["c", "d", "e"], + } + + # Add to the new coll + coll_new.add(**records_new) # type: ignore + + # Add to the default coll + coll_default.add(**records_default) # type: ignore + + # Make sure the collections are isolated + res = coll_new.get(include=["embeddings", "documents"]) + assert res["ids"] == records_new["ids"] + assert res["embeddings"] == records_new["embeddings"] + assert res["documents"] == records_new["documents"] + + res = coll_default.get(include=["embeddings", "documents"]) + assert res["ids"] == records_default["ids"] + assert res["embeddings"] == records_default["embeddings"] + assert res["documents"] == records_default["documents"] + + +def test_tenant_collections_add(client: Client) -> None: + client.reset() + + # Create two databases with same name in different tenants + admin_client = AdminClient.from_system(client._system) + admin_client.create_tenant("test_tenant1") + admin_client.create_tenant("test_tenant2") + admin_client.create_database("test_db", tenant="test_tenant1") + admin_client.create_database("test_db", tenant="test_tenant2") + + # Create collections in each database with same name + client.set_tenant(tenant="test_tenant1", database="test_db") + coll_tenant1 = client.create_collection("collection") + client.set_tenant(tenant="test_tenant2", database="test_db") + coll_tenant2 = client.create_collection("collection") + + records_tenant1 = { + "ids": ["a", "b", "c"], + "embeddings": [[1.0, 2.0, 3.0] for _ in range(3)], + "documents": ["a", "b", "c"], + } + + records_tenant2 = { + "ids": ["c", "d", "e"], + "embeddings": [[4.0, 5.0, 6.0] for _ in range(3)], + "documents": ["c", "d", "e"], + } + + # Add to the tenant1 coll + coll_tenant1.add(**records_tenant1) # type: ignore + + # Add to the tenant2 coll + coll_tenant2.add(**records_tenant2) # type: ignore + + # Make sure the collections are isolated + res = coll_tenant1.get(include=["embeddings", "documents"]) + assert res["ids"] == records_tenant1["ids"] + assert res["embeddings"] == records_tenant1["embeddings"] + assert res["documents"] == records_tenant1["documents"] + + res = coll_tenant2.get(include=["embeddings", "documents"]) + assert res["ids"] == records_tenant2["ids"] + assert res["embeddings"] == records_tenant2["embeddings"] + assert res["documents"] == records_tenant2["documents"] + + def test_min_len_name(client: Client) -> None: client.reset()