Skip to content

Commit

Permalink
modify ann accuracy and count
Browse files Browse the repository at this point in the history
  • Loading branch information
spikechroma committed Sep 11, 2024
1 parent 15c40e6 commit a70ae76
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 40 deletions.
6 changes: 4 additions & 2 deletions chromadb/test/distributed/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def test_add(
"metadatas": None,
"documents": None,
},
10,
n_records=len(ids),
n_results=10,
query_embeddings=[random_query],
)

Expand Down Expand Up @@ -92,6 +93,7 @@ def test_add_include_all_with_compaction_delay(client: ClientAPI) -> None:
"metadatas": None,
"documents": documents,
},
10,
n_records=len(ids),
n_results=10,
query_embeddings=[random_query_1, random_query_2],
)
34 changes: 23 additions & 11 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ def wrap_all(record_set: RecordSet) -> NormalizedRecordSet:
def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -> int:
normalized_record_set = wrap_all(cast(RecordSet, state_record_set))

# we need to replace empty lists with None to use get_n_items_from_record_set
# we need to replace empty lists with None within the record set state to use get_n_items_from_record_set
# get_n_items_from_record_set would throw an error if it encounters an empty list
record_set_with_empty_lists_replaced: NormalizedRecordSet = {
record_set_with_empty_lists_replaced: types.RecordSet = {
"ids": None,
"documents": None,
"metadatas": None,
"embeddings": None,
"images": None,
"uris": None,
}

all_fields_are_empty = True
Expand All @@ -100,11 +102,14 @@ def get_n_items_from_record_set_state(state_record_set: StateMachineRecordSet) -
if all_fields_are_empty:
return 0

return get_n_items_from_record_set(record_set_with_empty_lists_replaced)
(_, n) = types.get_n_items_from_record_set(record_set_with_empty_lists_replaced)
return n


def get_n_items_from_record_set(normalized_record_set: NormalizedRecordSet) -> int:
def get_n_items_from_record_set(record_set: RecordSet) -> int:
"""Get the number of items from a record set"""
normalized_record_set = wrap_all(record_set)

(_, n) = types.get_n_items_from_record_set(
{
"ids": normalized_record_set["ids"],
Expand All @@ -122,9 +127,16 @@ def get_n_items_from_record_set(normalized_record_set: NormalizedRecordSet) -> i
def count(collection: Collection, record_set: RecordSet) -> None:
"""The given collection count is equal to the number of embeddings"""
count = collection.count()
normalized_record_set = wrap_all(record_set)
n = get_n_items_from_record_set(record_set)
assert count == n

n = get_n_items_from_record_set(normalized_record_set)

def count_state_record_set(
collection: Collection, record_set: StateMachineRecordSet
) -> None:
"""The given collection count is equal to the number of embeddings within the state record set"""
count = collection.count()
n = get_n_items_from_record_set_state(record_set)
assert count == n


Expand Down Expand Up @@ -201,7 +213,7 @@ def metadatas_match(collection: Collection, record_set: RecordSet) -> None:
collection,
normalized_record_set,
"metadatas",
get_n_items_from_record_set(normalized_record_set),
get_n_items_from_record_set(record_set),
)


Expand All @@ -226,7 +238,7 @@ def documents_match(collection: Collection, record_set: RecordSet) -> None:
collection,
normalized_record_set,
"documents",
get_n_items_from_record_set(normalized_record_set),
get_n_items_from_record_set(record_set),
)


Expand All @@ -251,7 +263,7 @@ def embeddings_match(collection: Collection, record_set: RecordSet) -> None:
collection,
normalized_record_set,
"embeddings",
get_n_items_from_record_set(normalized_record_set),
get_n_items_from_record_set(record_set),
)


Expand Down Expand Up @@ -321,6 +333,7 @@ def fd_not_exceeding_threadpool_size(threadpool_size: int) -> None:
def ann_accuracy(
collection: Collection,
record_set: RecordSet,
n_records: int,
n_results: int = 1,
min_recall: float = 0.99,
embedding_function: Optional[types.EmbeddingFunction] = None, # type: ignore[type-arg]
Expand All @@ -330,8 +343,7 @@ def ann_accuracy(
"""Validate that the API performs nearest_neighbor searches correctly"""
normalized_record_set = wrap_all(record_set)

n = get_n_items_from_record_set(normalized_record_set)
if n == 0:
if n_records == 0:
return # nothing to test here

if normalized_record_set["ids"] is None:
Expand Down
8 changes: 5 additions & 3 deletions chromadb/test/property/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def _test_add(
# TODO: The type of add() is incorrect as it does not allow for metadatas
# like [{"a": 1}, None, {"a": 3}]
result = coll.add(**record_set) # type: ignore
if result["ids"] is not None:
if normalized_record_set["ids"] is not None:
normalized_record_set["ids"] = result["ids"]

n_records = invariants.get_n_items_from_record_set(normalized_record_set)
n_records = invariants.get_n_items_from_record_set(record_set)

# Only wait for compaction if the size of the collection is
# some minimal size
Expand All @@ -128,6 +128,7 @@ def _test_add(
invariants.ann_accuracy(
coll,
cast(strategies.RecordSet, normalized_record_set),
n_records=n_records,
n_results=n_results,
embedding_function=collection.embedding_function,
query_indices=list(range(i, min(i + batch_size, n_records))),
Expand All @@ -136,6 +137,7 @@ def _test_add(
invariants.ann_accuracy(
coll,
cast(strategies.RecordSet, normalized_record_set),
n_records=n_records,
n_results=n_results,
embedding_function=collection.embedding_function,
)
Expand Down Expand Up @@ -193,7 +195,7 @@ def test_add_large(
if results["ids"] is None:
raise ValueError("IDs should not be None")

n_records = invariants.get_n_items_from_record_set(normalized_record_set)
n_records = invariants.get_n_items_from_record_set(record_set)
if not NOT_CLUSTER_ONLY and should_compact and n_records > 10:
# Wait for the model to be updated, since the record set is larger, add some additional time
wait_for_version_increase(
Expand Down
6 changes: 5 additions & 1 deletion chromadb/test/property/test_cross_version_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,11 @@ def test_cycle_versions(
invariants.metadatas_match(coll, embeddings_strategy)
invariants.documents_match(coll, embeddings_strategy)
invariants.ids_match(coll, embeddings_strategy)
invariants.ann_accuracy(coll, embeddings_strategy)
invariants.ann_accuracy(
coll,
embeddings_strategy,
n_records=invariants.get_n_items_from_record_set(embeddings_strategy),
)
invariants.log_size_below_max(system, [coll], True)

# Shutdown system
Expand Down
22 changes: 6 additions & 16 deletions chromadb/test/property/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,9 @@ def upsert_embeddings(self, record_set: strategies.RecordSet) -> None:

@invariant()
def count(self) -> None:
n = invariants.get_n_items_from_record_set_state(self.record_set_state)
if n == 0:
assert self.collection.count() == 0
return

invariants.count(
invariants.count_state_record_set(
self.collection,
# this cast exists because count function is used for handling both StateMachineRecordSet and RecordSet
cast(strategies.RecordSet, self.record_set_state),
self.record_set_state,
)

@invariant()
Expand All @@ -229,11 +223,10 @@ def no_duplicates(self) -> None:

@invariant()
def ann_accuracy(self) -> None:
n = invariants.get_n_items_from_record_set_state(self.record_set_state)
if n == 0:
return

invariants.ann_accuracy(
n_records=invariants.get_n_items_from_record_set_state(
self.record_set_state
),
collection=self.collection,
record_set=cast(strategies.RecordSet, self.record_set_state),
min_recall=0.95,
Expand Down Expand Up @@ -411,11 +404,8 @@ def wait_for_compaction(self) -> None:
)
def add_embeddings(self, record_set: strategies.RecordSet) -> MultipleResults[ID]:
res = super().add_embeddings(record_set)
normalized_record_set: strategies.NormalizedRecordSet = invariants.wrap_all(
record_set
)

n_records = invariants.get_n_items_from_record_set(normalized_record_set)
n_records = invariants.get_n_items_from_record_set(record_set)
ids = [id for id in res]

print(
Expand Down
8 changes: 4 additions & 4 deletions chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def test_persist(
coll,
embeddings_strategy,
embedding_function=collection_strategy.embedding_function,
n_records=invariants.get_n_items_from_record_set(embeddings_strategy),
)

system_1.stop()
Expand All @@ -126,6 +127,7 @@ def test_persist(
coll,
embeddings_strategy,
embedding_function=collection_strategy.embedding_function,
n_records=invariants.get_n_items_from_record_set(embeddings_strategy),
)

system_2.stop()
Expand Down Expand Up @@ -207,13 +209,11 @@ def load_and_check(
name=collection_name,
embedding_function=strategies.not_implemented_embedding_function(), # type: ignore[arg-type]
)
invariants.count(coll, record_set)
invariants.count_state_record_set(coll, record_set) # type: ignore[arg-type]
invariants.metadatas_match_state_record_set(coll, record_set) # type: ignore[arg-type]
invariants.documents_match_state_record_set(coll, record_set) # type: ignore[arg-type]
invariants.ids_match(coll, record_set)

if invariants.get_n_items_from_record_set_state(record_set) > 0: # type: ignore[arg-type]
invariants.ann_accuracy(coll, record_set)
invariants.ann_accuracy(coll, record_set, n_records=invariants.get_n_items_from_record_set_state(record_set)) # type: ignore[arg-type]

system.stop()
except Exception as e:
Expand Down
8 changes: 5 additions & 3 deletions chromadb/test/test_multithreaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,17 @@ def _test_multithreaded_add(
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures: List[Future[Any]] = []
total_sent = -1
while total_sent < len(ids):
while total_sent < len(ids): # type: ignore[arg-type]
# Randomly grab up to 10% of the dataset and send it to the executor
batch_size = random.randint(1, N // 10)
to_send = min(batch_size, len(ids) - total_sent)
to_send = min(batch_size, len(ids) - total_sent) # type: ignore[arg-type]
start = total_sent + 1
end = total_sent + to_send + 1
if embeddings is not None and len(embeddings[start:end]) == 0:
break
future = executor.submit(
coll.add,
ids=ids[start:end],
ids=ids[start:end], # type: ignore[index]
embeddings=embeddings[start:end] if embeddings is not None else None,
metadatas=metadatas[start:end] if metadatas is not None else None, # type: ignore
documents=documents[start:end] if documents is not None else None,
Expand Down Expand Up @@ -93,6 +93,7 @@ def _test_multithreaded_add(
invariants.ann_accuracy(
coll,
records_set,
n_records=invariants.get_n_items_from_record_set(records_set),
n_results=n_results,
query_indices=query_indices,
)
Expand Down Expand Up @@ -210,6 +211,7 @@ def perform_operation(
invariants.ann_accuracy(
coll,
records_set,
n_records=invariants.get_n_items_from_record_set(records_set),
n_results=n_results,
query_indices=query_indices,
)
Expand Down

0 comments on commit a70ae76

Please sign in to comment.