diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 04a55c01d53..a77515e1d99 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -31,7 +31,7 @@ from pypika.queries import QueryBuilder import pypika.functions as fn from pypika.terms import Criterion -from itertools import islice, groupby +from itertools import groupby from functools import reduce import sqlite3 @@ -121,6 +121,12 @@ def get_metadata( "embeddings", "embedding_metadata", "embedding_fulltext_search" ) + limit = limit or 2**63 - 1 + offset = offset or 0 + + if limit < 0: + raise ValueError("Limit cannot be negative") + q = ( ( self._db.querybuilder() @@ -138,26 +144,74 @@ def get_metadata( metadata_t.float_value, metadata_t.bool_value, ) - .where( - embeddings_t.segment_id == ParameterValue(self._db.uuid_to_db(self._id)) - ) - .orderby(embeddings_t.id) + .orderby(embeddings_t.embedding_id) ) - if where: - q = q.where(self._where_map_criterion(q, where, embeddings_t, metadata_t)) - if where_document: - q = q.where( - self._where_doc_criterion(q, where_document, embeddings_t, fulltext_t) + # If there is a query that touches the metadata table, it uses + # where and where_document filters, we treat this case seperately + if where is not None or where_document is not None: + metadata_q = ( + self._db.querybuilder() + .from_(metadata_t) + .select(metadata_t.id) + .join(embeddings_t) + .on(embeddings_t.id == metadata_t.id) + .orderby(embeddings_t.embedding_id) + .where( + embeddings_t.segment_id + == ParameterValue(self._db.uuid_to_db(self._id)) + ) + .distinct() # These are embedding ids ) - if ids: - q = q.where(embeddings_t.embedding_id.isin(ParameterValue(ids))) + if where: + metadata_q = metadata_q.where( + self._where_map_criterion( + metadata_q, where, metadata_t, embeddings_t + ) + ) + if where_document: + metadata_q = metadata_q.where( + self._where_doc_criterion( + metadata_q, where_document, metadata_t, fulltext_t, embeddings_t + ) + ) + if ids is not None: + metadata_q = metadata_q.where( + embeddings_t.embedding_id.isin(ParameterValue(ids)) + ) + + metadata_q = metadata_q.limit(limit) + metadata_q = metadata_q.offset(offset) + + q = q.where(embeddings_t.id.isin(metadata_q)) + else: + # In the case where we don't use the metadata table + # We have to apply limit/offset to embeddings and then join + # with metadata + embeddings_q = ( + self._db.querybuilder() + .from_(embeddings_t) + .select(embeddings_t.id) + .where( + embeddings_t.segment_id + == ParameterValue(self._db.uuid_to_db(self._id)) + ) + .orderby(embeddings_t.embedding_id) + .limit(limit) + .offset(offset) + ) + + if ids is not None: + embeddings_q = embeddings_q.where( + embeddings_t.embedding_id.isin(ParameterValue(ids)) + ) + + q = q.where(embeddings_t.id.isin(embeddings_q)) - limit = limit or 2**63 - 1 - offset = offset or 0 with self._db.tx() as cur: - return list(islice(self._records(cur, q), offset, offset + limit)) + # Execute the query with the limit and offset already applied + return list(self._records(cur, q)) def _records( self, cur: Cursor, q: QueryBuilder @@ -430,19 +484,19 @@ def _write_metadata(self, records: Sequence[EmbeddingRecord]) -> None: "SqliteMetadataSegment._where_map_criterion", OpenTelemetryGranularity.ALL ) def _where_map_criterion( - self, q: QueryBuilder, where: Where, embeddings_t: Table, metadata_t: Table + self, q: QueryBuilder, where: Where, metadata_t: Table, embeddings_t: Table ) -> Criterion: clause: list[Criterion] = [] for k, v in where.items(): if k == "$and": criteria = [ - self._where_map_criterion(q, w, embeddings_t, metadata_t) + self._where_map_criterion(q, w, metadata_t, embeddings_t) for w in cast(Sequence[Where], v) ] clause.append(reduce(lambda x, y: x & y, criteria)) elif k == "$or": criteria = [ - self._where_map_criterion(q, w, embeddings_t, metadata_t) + self._where_map_criterion(q, w, metadata_t, embeddings_t) for w in cast(Sequence[Where], v) ] clause.append(reduce(lambda x, y: x | y, criteria)) @@ -455,7 +509,7 @@ def _where_map_criterion( .where(metadata_t.key == ParameterValue(k)) .where(_where_clause(expr, metadata_t)) ) - clause.append(embeddings_t.id.isin(sq)) + clause.append(metadata_t.id.isin(sq)) return reduce(lambda x, y: x & y, clause) @trace_method( @@ -465,19 +519,24 @@ def _where_doc_criterion( self, q: QueryBuilder, where: WhereDocument, - embeddings_t: Table, + metadata_t: Table, fulltext_t: Table, + embeddings_t: Table, ) -> Criterion: for k, v in where.items(): if k == "$and": criteria = [ - self._where_doc_criterion(q, w, embeddings_t, fulltext_t) + self._where_doc_criterion( + q, w, metadata_t, fulltext_t, embeddings_t + ) for w in cast(Sequence[WhereDocument], v) ] return reduce(lambda x, y: x & y, criteria) elif k == "$or": criteria = [ - self._where_doc_criterion(q, w, embeddings_t, fulltext_t) + self._where_doc_criterion( + q, w, metadata_t, fulltext_t, embeddings_t + ) for w in cast(Sequence[WhereDocument], v) ] return reduce(lambda x, y: x | y, criteria) @@ -491,7 +550,7 @@ def _where_doc_criterion( .select(fulltext_t.rowid) .where(fulltext_t.string_value.like(ParameterValue(search_term))) ) - return embeddings_t.id.isin(sq) + return metadata_t.id.isin(sq) elif k == "$not_contains": v = cast(str, v) search_term = f"%{v}%" diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index d42213d76de..9129c023df7 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -125,7 +125,6 @@ def _filter_embedding_set( """Return IDs from the embedding set that match the given filter object""" normalized_record_set = invariants.wrap_all(record_set) - ids = set(normalized_record_set["ids"]) filter_ids = filter["ids"] @@ -154,6 +153,7 @@ def _filter_embedding_set( ) if not _filter_where_doc_clause(filter["where_document"], documents[i]): ids.discard(normalized_record_set["ids"][i]) + return list(ids) @@ -201,6 +201,52 @@ def test_filterable_metadata_get( assert sorted(result_ids) == sorted(expected_ids) +@settings( + suppress_health_check=[ + HealthCheck.function_scoped_fixture, + HealthCheck.large_base_example, + ] +) # type: ignore +@given( + collection=collection_st, + record_set=recordset_st, + filters=st.lists(strategies.filters(collection_st, recordset_st), min_size=1), + limit=st.integers(min_value=1, max_value=10), + offset=st.integers(min_value=0, max_value=10), +) +def test_filterable_metadata_get_limit_offset( + caplog, + api: ServerAPI, + collection: strategies.Collection, + record_set, + filters, + limit, + offset, +) -> None: + caplog.set_level(logging.ERROR) + + api.reset() + coll = api.create_collection( + name=collection.name, + metadata=collection.metadata, # type: ignore + embedding_function=collection.embedding_function, + ) + + if not invariants.is_metadata_valid(invariants.wrap_all(record_set)): + with pytest.raises(Exception): + coll.add(**record_set) + return + + coll.add(**record_set) + for filter in filters: + # add limit and offset to filter + filter["limit"] = limit + filter["offset"] = offset + result_ids = coll.get(**filter)["ids"] + expected_ids = _filter_embedding_set(record_set, filter) + assert sorted(result_ids) == sorted(expected_ids)[offset : offset + limit] + + @settings( suppress_health_check=[ HealthCheck.function_scoped_fixture, diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index b5704c2a503..ef6400b210e 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -121,6 +121,15 @@ def _build_document(i: int) -> str: metadata=None, ) +segment_definition2 = Segment( + id=uuid.uuid4(), + type="test_type", + scope=SegmentScope.METADATA, + topic="persistent://test/test/test_topic_2", + collection=None, + metadata=None, +) + def sync(segment: MetadataReader, seq_id: SeqId) -> None: # Try for up to 5 seconds, then throw a TimeoutError @@ -569,6 +578,50 @@ def _test_update( assert len(results) == 0 +def test_limit( + system: System, + sample_embeddings: Iterator[SubmitEmbeddingRecord], + produce_fns: ProducerFn, +) -> None: + producer = system.instance(Producer) + system.reset_state() + + topic = str(segment_definition["topic"]) + max_id = produce_fns(producer, topic, sample_embeddings, 3)[1][-1] + + topic2 = str(segment_definition2["topic"]) + max_id2 = produce_fns(producer, topic2, sample_embeddings, 3)[1][-1] + + segment = SqliteMetadataSegment(system, segment_definition) + segment.start() + + segment2 = SqliteMetadataSegment(system, segment_definition2) + segment2.start() + + sync(segment, max_id) + sync(segment2, max_id2) + + assert segment.count() == 3 + + for i in range(3): + max_id = producer.submit_embedding(topic, next(sample_embeddings)) + + sync(segment, max_id) + + assert segment.count() == 6 + + res = segment.get_metadata(limit=3) + assert len(res) == 3 + + # if limit is negative, throw error + with pytest.raises(ValueError): + segment.get_metadata(limit=-1) + + # if offset is more than number of results, return empty list + res = segment.get_metadata(limit=3, offset=10) + assert len(res) == 0 + + def test_delete_segment( system: System, sample_embeddings: Iterator[SubmitEmbeddingRecord], diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index bfb0dc6acce..d149299c4f8 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -5,4 +5,4 @@ export { IEmbeddingFunction } from './embeddings/IEmbeddingFunction'; export { OpenAIEmbeddingFunction } from './embeddings/OpenAIEmbeddingFunction'; export { CohereEmbeddingFunction } from './embeddings/CohereEmbeddingFunction'; export { IncludeEnum, GetParams } from './types'; -export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction'; \ No newline at end of file +export { HuggingFaceEmbeddingServerFunction } from './embeddings/HuggingFaceEmbeddingServerFunction';