-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7845ec8
commit 3cceed5
Showing
10 changed files
with
597 additions
and
639 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
import pytest | ||
import chromadb | ||
import traceback | ||
|
||
def test_ssl_self_signed(client_ssl): | ||
if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"): | ||
pytest.skip("Skipping test for integration test") | ||
client_ssl.heartbeat() | ||
|
||
|
||
def test_ssl_self_signed_without_ssl_verify(client_ssl): | ||
if os.environ.get("CHROMA_INTEGRATION_TEST_ONLY"): | ||
pytest.skip("Skipping test for integration test") | ||
client_ssl.heartbeat() | ||
_port = client_ssl._server._settings.chroma_server_http_port | ||
with pytest.raises(ValueError) as e: | ||
chromadb.HttpClient(ssl=True, port=_port) | ||
stack_trace = traceback.format_exception( | ||
type(e.value), e.value, e.value.__traceback__ | ||
) | ||
client_ssl.clear_system_cache() | ||
assert "CERTIFICATE_VERIFY_FAILED" in "".join(stack_trace) | ||
|
||
# test get_version | ||
def test_get_version(client): | ||
client.reset() | ||
version = client.get_version() | ||
|
||
# assert version matches the pattern x.y.z | ||
import re | ||
|
||
assert re.match(r"\d+\.\d+\.\d+", version) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import pytest | ||
import numpy as np | ||
|
||
from chromadb.test.api.utils import local_persist_api, batch_records,records | ||
from chromadb.errors import InvalidCollectionException | ||
|
||
@pytest.mark.parametrize("api_fixture", [local_persist_api]) | ||
def test_persist(api_fixture, request): | ||
client = request.getfixturevalue(api_fixture.__name__) | ||
|
||
client.reset() | ||
|
||
collection = client.create_collection("testspace") | ||
|
||
collection.add(**batch_records) | ||
|
||
assert collection.count() == 2 | ||
|
||
client = request.getfixturevalue(api_fixture.__name__) | ||
collection = client.get_collection("testspace") | ||
assert collection.count() == 2 | ||
|
||
client.delete_collection("testspace") | ||
|
||
client = request.getfixturevalue(api_fixture.__name__) | ||
assert client.list_collections() == [] | ||
|
||
def test_get_or_create(client): | ||
client.reset() | ||
|
||
collection = client.create_collection("testspace") | ||
|
||
collection.add(**batch_records) | ||
|
||
assert collection.count() == 2 | ||
|
||
with pytest.raises(Exception): | ||
collection = client.create_collection("testspace") | ||
|
||
collection = client.get_or_create_collection("testspace") | ||
|
||
assert collection.count() == 2 | ||
|
||
# test delete_collection | ||
def test_delete_collection(client): | ||
client.reset() | ||
collection = client.create_collection("test_delete_collection") | ||
collection.add(**records) | ||
|
||
assert len(client.list_collections()) == 1 | ||
client.delete_collection("test_delete_collection") | ||
assert len(client.list_collections()) == 0 | ||
|
||
def test_multiple_collections(client): | ||
embeddings1 = np.random.rand(10, 512).astype(np.float32).tolist() | ||
embeddings2 = np.random.rand(10, 512).astype(np.float32).tolist() | ||
ids1 = [f"http://example.com/1/{i}" for i in range(len(embeddings1))] | ||
ids2 = [f"http://example.com/2/{i}" for i in range(len(embeddings2))] | ||
|
||
client.reset() | ||
coll1 = client.create_collection("coll1") | ||
coll1.add(embeddings=embeddings1, ids=ids1) | ||
|
||
coll2 = client.create_collection("coll2") | ||
coll2.add(embeddings=embeddings2, ids=ids2) | ||
|
||
assert len(client.list_collections()) == 2 | ||
assert coll1.count() == len(embeddings1) | ||
assert coll2.count() == len(embeddings2) | ||
|
||
results1 = coll1.query(query_embeddings=embeddings1[0], n_results=1) | ||
results2 = coll2.query(query_embeddings=embeddings2[0], n_results=1) | ||
|
||
assert results1["ids"][0][0] == ids1[0] | ||
assert results2["ids"][0][0] == ids2[0] | ||
|
||
def test_collection_peek_with_invalid_collection_throws(client): | ||
client.reset() | ||
collection = client.create_collection("test") | ||
client.delete_collection("test") | ||
|
||
with pytest.raises( | ||
InvalidCollectionException, match=r"Collection .* does not exist." | ||
): | ||
collection.peek() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pytest | ||
from chromadb.test.test_api import contains_records | ||
|
||
# test to make sure delete error on invalid id input | ||
def test_delete_invalid_id(client): | ||
client.reset() | ||
collection = client.create_collection("test_invalid_id") | ||
|
||
# Delete with malformed ids | ||
with pytest.raises(ValueError) as e: | ||
collection.delete(ids=["valid", 0]) | ||
assert "ID" in str(e.value) | ||
|
||
|
||
def test_delete_where_document(client): | ||
client.reset() | ||
collection = client.create_collection("test_delete_where_document") | ||
collection.add(**contains_records) | ||
|
||
collection.delete(where_document={"$contains": "doc1"}) | ||
assert collection.count() == 1 | ||
|
||
collection.delete(where_document={"$contains": "bad"}) | ||
assert collection.count() == 1 | ||
|
||
collection.delete(where_document={"$contains": "great"}) | ||
assert collection.count() == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from typing import cast | ||
from chromadb.utils.embedding_functions import ( | ||
DefaultEmbeddingFunction, | ||
) | ||
from chromadb.api.types import EmbeddingFunction, Documents | ||
# test default embedding function | ||
|
||
def test_default_embedding(): | ||
embedding_function = cast(EmbeddingFunction[Documents], DefaultEmbeddingFunction()) | ||
docs = ["this is a test" for _ in range(64)] | ||
embeddings = embedding_function(docs) | ||
assert len(embeddings) == 64 |
Oops, something went wrong.