Skip to content

Commit

Permalink
[ENH] Fix for record pagination (#1450)
Browse files Browse the repository at this point in the history
This fixes the pagination on records. Before, we would select all data
and then subsample — this obviously is not very performant, and this
change fixes it. Credit to @HammadB


- [x] Fix bug with limit first
- [x] add tests
  • Loading branch information
jeffchuber authored Dec 12, 2023
1 parent a1e6e01 commit 16d3fe9
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 25 deletions.
105 changes: 82 additions & 23 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pypika.queries import QueryBuilder
import pypika.functions as fn
from pypika.terms import Criterion
from itertools import islice, groupby
from itertools import groupby
from functools import reduce
import sqlite3

Expand Down Expand Up @@ -121,6 +121,12 @@ def get_metadata(
"embeddings", "embedding_metadata", "embedding_fulltext_search"
)

limit = limit or 2**63 - 1
offset = offset or 0

if limit < 0:
raise ValueError("Limit cannot be negative")

q = (
(
self._db.querybuilder()
Expand All @@ -138,26 +144,74 @@ def get_metadata(
metadata_t.float_value,
metadata_t.bool_value,
)
.where(
embeddings_t.segment_id == ParameterValue(self._db.uuid_to_db(self._id))
)
.orderby(embeddings_t.id)
.orderby(embeddings_t.embedding_id)
)

if where:
q = q.where(self._where_map_criterion(q, where, embeddings_t, metadata_t))
if where_document:
q = q.where(
self._where_doc_criterion(q, where_document, embeddings_t, fulltext_t)
        # If the query touches the metadata table (i.e. it uses where or
        # where_document filters), we treat this case separately
if where is not None or where_document is not None:
metadata_q = (
self._db.querybuilder()
.from_(metadata_t)
.select(metadata_t.id)
.join(embeddings_t)
.on(embeddings_t.id == metadata_t.id)
.orderby(embeddings_t.embedding_id)
.where(
embeddings_t.segment_id
== ParameterValue(self._db.uuid_to_db(self._id))
)
.distinct() # These are embedding ids
)

if ids:
q = q.where(embeddings_t.embedding_id.isin(ParameterValue(ids)))
if where:
metadata_q = metadata_q.where(
self._where_map_criterion(
metadata_q, where, metadata_t, embeddings_t
)
)
if where_document:
metadata_q = metadata_q.where(
self._where_doc_criterion(
metadata_q, where_document, metadata_t, fulltext_t, embeddings_t
)
)
if ids is not None:
metadata_q = metadata_q.where(
embeddings_t.embedding_id.isin(ParameterValue(ids))
)

metadata_q = metadata_q.limit(limit)
metadata_q = metadata_q.offset(offset)

q = q.where(embeddings_t.id.isin(metadata_q))
else:
            # In the case where we don't use the metadata table,
            # we have to apply limit/offset to the embeddings and then
            # join with metadata
embeddings_q = (
self._db.querybuilder()
.from_(embeddings_t)
.select(embeddings_t.id)
.where(
embeddings_t.segment_id
== ParameterValue(self._db.uuid_to_db(self._id))
)
.orderby(embeddings_t.embedding_id)
.limit(limit)
.offset(offset)
)

if ids is not None:
embeddings_q = embeddings_q.where(
embeddings_t.embedding_id.isin(ParameterValue(ids))
)

q = q.where(embeddings_t.id.isin(embeddings_q))

limit = limit or 2**63 - 1
offset = offset or 0
with self._db.tx() as cur:
return list(islice(self._records(cur, q), offset, offset + limit))
# Execute the query with the limit and offset already applied
return list(self._records(cur, q))

def _records(
self, cur: Cursor, q: QueryBuilder
Expand Down Expand Up @@ -430,19 +484,19 @@ def _write_metadata(self, records: Sequence[EmbeddingRecord]) -> None:
"SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL
)
def _where_map_criterion(
self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table
self, q: QueryBuilder, where: Where, metadata_t: Table, embeddings_t: Table
) -> Criterion:
clause: list[Criterion] = []
for k, v in where.items():
if k == "$and":
criteria = [
self._where_map_criterion(q, w, embeddings_t, metadata_t)
self._where_map_criterion(q, w, metadata_t, embeddings_t)
for w in cast(Sequence[Where], v)
]
clause.append(reduce(lambda x, y: x & y, criteria))
elif k == "$or":
criteria = [
self._where_map_criterion(q, w, embeddings_t, metadata_t)
self._where_map_criterion(q, w, metadata_t, embeddings_t)
for w in cast(Sequence[Where], v)
]
clause.append(reduce(lambda x, y: x | y, criteria))
Expand All @@ -455,7 +509,7 @@ def _where_map_criterion(
.where(metadata_t.key == ParameterValue(k))
.where(_where_clause(expr, metadata_t))
)
clause.append(embeddings_t.id.isin(sq))
clause.append(metadata_t.id.isin(sq))
return reduce(lambda x, y: x & y, clause)

@trace_method(
Expand All @@ -465,19 +519,24 @@ def _where_doc_criterion(
self,
q: QueryBuilder,
where: WhereDocument,
embeddings_t: Table,
metadata_t: Table,
fulltext_t: Table,
embeddings_t: Table,
) -> Criterion:
for k, v in where.items():
if k == "$and":
criteria = [
self._where_doc_criterion(q, w, embeddings_t, fulltext_t)
self._where_doc_criterion(
q, w, metadata_t, fulltext_t, embeddings_t
)
for w in cast(Sequence[WhereDocument], v)
]
return reduce(lambda x, y: x & y, criteria)
elif k == "$or":
criteria = [
self._where_doc_criterion(q, w, embeddings_t, fulltext_t)
self._where_doc_criterion(
q, w, metadata_t, fulltext_t, embeddings_t
)
for w in cast(Sequence[WhereDocument], v)
]
return reduce(lambda x, y: x | y, criteria)
Expand All @@ -491,7 +550,7 @@ def _where_doc_criterion(
.select(fulltext_t.rowid)
.where(fulltext_t.string_value.like(ParameterValue(search_term)))
)
return embeddings_t.id.isin(sq)
return metadata_t.id.isin(sq)
elif k == "$not_contains":
v = cast(str, v)
search_term = f"%{v}%"
Expand Down
48 changes: 47 additions & 1 deletion chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def _filter_embedding_set(
"""Return IDs from the embedding set that match the given filter object"""

normalized_record_set = invariants.wrap_all(record_set)

ids = set(normalized_record_set["ids"])

filter_ids = filter["ids"]
Expand Down Expand Up @@ -154,6 +153,7 @@ def _filter_embedding_set(
)
if not _filter_where_doc_clause(filter["where_document"], documents[i]):
ids.discard(normalized_record_set["ids"][i])

return list(ids)


Expand Down Expand Up @@ -201,6 +201,52 @@ def test_filterable_metadata_get(
assert sorted(result_ids) == sorted(expected_ids)


@settings(
    suppress_health_check=[
        HealthCheck.function_scoped_fixture,
        HealthCheck.large_base_example,
    ]
) # type: ignore
@given(
    collection=collection_st,
    record_set=recordset_st,
    filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1),
    limit=st.integers(min_value=1, max_value=10),
    offset=st.integers(min_value=0, max_value=10),
)
def test_filterable_metadata_get_limit_offset(
    caplog,
    api: ServerAPI,
    collection: strategies.Collection,
    record_set,
    filters,
    limit,
    offset,
) -> None:
    """Property test: get() with limit/offset returns exactly the
    [offset : offset + limit] slice of the sorted, filtered id set
    for every generated filter."""
    caplog.set_level(logging.ERROR)

    api.reset()
    coll = api.create_collection(
        name=collection.name,
        metadata=collection.metadata, # type: ignore
        embedding_function=collection.embedding_function,
    )

    # Records with invalid metadata must be rejected by add(); there is
    # nothing further to verify in that case.
    normalized = invariants.wrap_all(record_set)
    if not invariants.is_metadata_valid(normalized):
        with pytest.raises(Exception):
            coll.add(**record_set)
        return

    coll.add(**record_set)
    for filter_spec in filters:
        # Attach pagination parameters to the generated filter before querying.
        filter_spec.update({"limit": limit, "offset": offset})
        result_ids = coll.get(**filter_spec)["ids"]
        expected_ids = sorted(_filter_embedding_set(record_set, filter_spec))
        assert sorted(result_ids) == expected_ids[offset : offset + limit]


@settings(
suppress_health_check=[
HealthCheck.function_scoped_fixture,
Expand Down
53 changes: 53 additions & 0 deletions chromadb/test/segment/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ def _build_document(i: int) -> str:
metadata=None,
)

# A second metadata segment on a distinct topic. Used to verify that
# counts and limit/offset queries are scoped to a single segment rather
# than applied globally (see test_limit below).
segment_definition2 = Segment(
    id=uuid.uuid4(),
    type="test_type",
    scope=SegmentScope.METADATA,
    topic="persistent://test/test/test_topic_2",
    collection=None,
    metadata=None,
)


def sync(segment: MetadataReader, seq_id: SeqId) -> None:
# Try for up to 5 seconds, then throw a TimeoutError
Expand Down Expand Up @@ -569,6 +578,50 @@ def _test_update(
assert len(results) == 0


def test_limit(
    system: System,
    sample_embeddings: Iterator[SubmitEmbeddingRecord],
    produce_fns: ProducerFn,
) -> None:
    """Exercise limit/offset handling in SqliteMetadataSegment.get_metadata.

    Writes records into two separate segments (distinct topics) and checks:
    - count() and limit apply per-segment, not globally,
    - a limit smaller than the record count truncates the result,
    - a negative limit raises ValueError,
    - an offset past the last record yields an empty result.
    """
    producer = system.instance(Producer)
    system.reset_state()

    # Populate the primary segment's topic with 3 records; keep the last
    # sequence id so we can wait for the segment to catch up.
    topic = str(segment_definition["topic"])
    max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1]

    # Populate a second, independent segment so we can confirm that the
    # limit query below is not affected by data in other segments.
    topic2 = str(segment_definition2["topic"])
    max_id2 = produce_fns(producer, topic2, sample_embeddings, 3)[1][-1]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    segment2 = SqliteMetadataSegment(system, segment_definition2)
    segment2.start()

    # Block until each segment has consumed up to its own max sequence id.
    sync(segment, max_id)
    sync(segment2, max_id2)

    # Only the 3 records for this segment's topic are counted.
    assert segment.count() == 3

    # Add 3 more records to the first segment only.
    for i in range(3):
        max_id = producer.submit_embedding(topic, next(sample_embeddings))

    sync(segment, max_id)

    assert segment.count() == 6

    res = segment.get_metadata(limit=3)
    assert len(res) == 3

    # if limit is negative, throw error
    with pytest.raises(ValueError):
        segment.get_metadata(limit=-1)

    # if offset is more than number of results, return empty list
    res = segment.get_metadata(limit=3, offset=10)
    assert len(res) == 0


def test_delete_segment(
system: System,
sample_embeddings: Iterator[SubmitEmbeddingRecord],
Expand Down
2 changes: 1 addition & 1 deletion clients/js/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ export { IEmbeddingFunction } from './embeddings/IEmbeddingFunction';
export { OpenAIEmbeddingFunction } from './embeddings/OpenAIEmbeddingFunction';
export { CohereEmbeddingFunction } from './embeddings/CohereEmbeddingFunction';
export { IncludeEnum, GetParams } from './types';
export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction';
export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction';

0 comments on commit 16d3fe9

Please sign in to comment.