diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 514a239dced..44ba16bf826 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -213,8 +213,10 @@ def validate_embedding_function( "Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n" ) + L = TypeVar("L", covariant=True) + class DataLoader(Protocol[L]): def __call__(self, uris: URIs) -> L: ... @@ -404,7 +406,7 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument: f"Expected where document to have exactly one operator, got {where_document}" ) for operator, operand in where_document.items(): - if operator not in ["$contains", "$and", "$or"]: + if operator not in ["$contains", "$not_contains", "$and", "$or"]: raise ValueError( f"Expected where document operator to be one of $contains, $and, $or, got {operator}" ) diff --git a/chromadb/segment/impl/metadata/sqlite.py b/chromadb/segment/impl/metadata/sqlite.py index 01525ee8679..04a55c01d53 100644 --- a/chromadb/segment/impl/metadata/sqlite.py +++ b/chromadb/segment/impl/metadata/sqlite.py @@ -492,6 +492,19 @@ def _where_doc_criterion( .where(fulltext_t.string_value.like(ParameterValue(search_term))) ) return embeddings_t.id.isin(sq) + elif k == "$not_contains": + v = cast(str, v) + search_term = f"%{v}%" + + sq = ( + self._db.querybuilder() + .from_(fulltext_t) + .select(fulltext_t.rowid) + .where( + fulltext_t.string_value.not_like(ParameterValue(search_term)) + ) + ) + return embeddings_t.id.isin(sq) else: raise ValueError(f"Unknown where_doc operator {k}") raise ValueError("Empty where_doc") diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 142fbc8b3f2..af3cfd2c820 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -518,7 +518,13 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu word = draw(st.sampled_from(collection.known_document_keywords)) else: word = draw(safe_text) - return {"$contains": word} + + op: WhereOperator = draw(st.sampled_from(["$contains", "$not_contains"])) + if op == "$contains": + return {"$contains": word} + else: + assert op == "$not_contains" + return {"$not_contains": word} def binary_operator_clause( diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index b48d6e648b3..d42213d76de 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -94,12 +94,23 @@ def _filter_where_doc_clause(clause: WhereDocument, doc: Document) -> bool: # Simple $contains clause assert isinstance(expr, str) if key == "$contains": + if not doc: + return False # SQLite FTS handles % and _ as word boundaries that are ignored so we need to # treat them as wildcards if "%" in expr or "_" in expr: expr = expr.replace("%", ".").replace("_", ".") return re.search(expr, doc) is not None return expr in doc + elif key == "$not_contains": + if not doc: + return False + # SQLite FTS handles % and _ as word boundaries that are ignored so we need to + # treat them as wildcards + if "%" in expr or "_" in expr: + expr = expr.replace("%", ".").replace("_", ".") + return re.search(expr, doc) is None + return expr not in doc else: raise ValueError("Unknown operator: {}".format(key)) @@ -118,6 +129,7 @@ def _filter_embedding_set( ids = set(normalized_record_set["ids"]) filter_ids = filter["ids"] + if filter_ids is not None: filter_ids = invariants.wrap(filter_ids) assert filter_ids is not None diff --git a/chromadb/test/segment/test_metadata.py b/chromadb/test/segment/test_metadata.py index 95806a53965..b5704c2a503 100644 --- a/chromadb/test/segment/test_metadata.py +++ b/chromadb/test/segment/test_metadata.py @@ -296,10 +296,22 @@ def test_fulltext( result = segment.get_metadata(where_document={"$contains": "four two"}) assert len(result) == 1 + # Test not_contains + result = segment.get_metadata(where_document={"$not_contains": "four two"}) + assert len(result) == len( + [i for i in range(1, 100) if "four two" not in _build_document(i)] + ) + # Test many results result = segment.get_metadata(where_document={"$contains": "zero"}) assert len(result) == 9 + # Test not_contains + result = segment.get_metadata(where_document={"$not_contains": "zero"}) + assert len(result) == len( + [i for i in range(1, 100) if "zero" not in _build_document(i)] + ) + # test $and result = segment.get_metadata( where_document={"$and": [{"$contains": "four"}, {"$contains": "two"}]} @@ -307,6 +319,17 @@ def test_fulltext( assert len(result) == 2 assert set([r["id"] for r in result]) == {"embedding_42", "embedding_24"} + result = segment.get_metadata( + where_document={"$and": [{"$not_contains": "four"}, {"$not_contains": "two"}]} + ) + assert len(result) == len( + [ + i + for i in range(1, 100) + if "four" not in _build_document(i) and "two" not in _build_document(i) + ] + ) + # test $or result = segment.get_metadata( where_document={"$or": [{"$contains": "zero"}, {"$contains": "one"}]} @@ -316,6 +339,17 @@ def test_fulltext( expected = set([f"embedding_{i}" for i in set(ones + zeros)]) assert set([r["id"] for r in result]) == expected + result = segment.get_metadata( + where_document={"$or": [{"$not_contains": "zero"}, {"$not_contains": "one"}]} + ) + assert len(result) == len( + [ + i + for i in range(1, 100) + if "zero" not in _build_document(i) or "one" not in _build_document(i) + ] + ) + # test combo with where clause (negative case) result = segment.get_metadata( where={"int_key": {"$eq": 42}}, where_document={"$contains": "zero"} diff --git a/chromadb/types.py b/chromadb/types.py index 2478a1d1bb5..db22417ace1 100644 --- a/chromadb/types.py +++ b/chromadb/types.py @@ -156,7 +156,9 @@ class VectorQueryResult(TypedDict): Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]] ] -WhereDocumentOperator = Union[Literal["$contains"], LogicalOperator] +WhereDocumentOperator = Union[ + Literal["$contains"], Literal["$not_contains"], LogicalOperator +] WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]] diff --git a/examples/basic_functionality/where_filtering.ipynb b/examples/basic_functionality/where_filtering.ipynb index e7e606b3df1..89c951f8873 100644 --- a/examples/basic_functionality/where_filtering.ipynb +++ b/examples/basic_functionality/where_filtering.ipynb @@ -87,7 +87,9 @@ "{'ids': ['id7'],\n", " 'embeddings': None,\n", " 'metadatas': [{'status': 'read'}],\n", - " 'documents': ['A document that discusses international affairs']}" + " 'documents': ['A document that discusses international affairs'],\n", + " 'uris': None,\n", + " 'data': None}" ] }, "execution_count": 5, @@ -112,7 +114,9 @@ " 'embeddings': None,\n", " 'metadatas': [{'status': 'read'}, {'status': 'unread'}],\n", " 'documents': ['A document that discusses domestic policy',\n", - " 'A document that discusses global affairs']}" + " 'A document that discusses global affairs'],\n", + " 'uris': None,\n", + " 'data': None}" ] }, "execution_count": 6, @@ -141,7 +145,9 @@ " 'embeddings': None,\n", " 'documents': [['A document that discusses international affairs',\n", " 'A document that discusses international affairs',\n", - " 'A document that discusses global affairs']]}" + " 'A document that discusses global affairs']],\n", + " 'uris': None,\n", + " 'data': None}" ] }, "execution_count": 7, @@ -155,8 +161,49 @@ "collection.query(query_embeddings=[[0, 0, 0]], where_document={\"$contains\": \"affairs\"}, n_results=5)" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'ids': [['id5', 'id3', 'id7', 'id8', 'id4']],\n", + " 'distances': [[16.740001678466797,\n", + " 16.740001678466797,\n", + " 16.740001678466797,\n", + " 87.22000122070312,\n", + " 87.22000122070312]],\n", + " 'metadatas': [[{'status': 'read'},\n", + " {'status': 'read'},\n", + " {'status': 'read'},\n", + " {'status': 'unread'},\n", + " {'status': 'unread'}]],\n", + " 'embeddings': None,\n", + " 'documents': [['A document that discusses chocolate',\n", + " 'A document that discusses kittens',\n", + " 'A document that discusses international affairs',\n", + " 'A document that discusses global affairs',\n", + " 'A document that discusses dogs']],\n", + " 'uris': None,\n", + " 'data': None}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "collection.query(query_embeddings=[[0, 0, 0]], where_document={\"$not_contains\": \"domestic policy\"}, n_results=5)" + ] + }, { "cell_type": "markdown", + "metadata": { + "collapsed": false + }, "source": [ "# Where Filtering With Logical Operators\n", "This section demonstrates how one can use the logical operators in `where` filtering.\n", @@ -164,27 +211,31 @@ "Chroma currently supports: `$and` and `$or`operators.\n", "\n", "> Note: Logical operators can be nested" - ], - "metadata": { - "collapsed": false - } + ] }, { "cell_type": "code", - "execution_count": 8, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/tazarov/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [02:59<00:00, 463kiB/s] \n" - ] + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-11T18:45:52.663345Z", + "start_time": "2023-08-11T18:42:50.970414Z" }, + "collapsed": false + }, + "outputs": [ { "data": { - "text/plain": "{'ids': ['1', '2'],\n 'embeddings': None,\n 'metadatas': [{'author': 'john'}, {'author': 'jack'}],\n 'documents': ['Article by john', 'Article by Jack']}" + "text/plain": [ + "{'ids': ['1', '2'],\n", + " 'embeddings': None,\n", + " 'metadatas': [{'author': 'john'}, {'author': 'jack'}],\n", + " 'documents': ['Article by john', 'Article by Jack'],\n", + " 'uris': None,\n", + " 'data': None}" + ] }, - "execution_count": 8, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -198,24 +249,31 @@ " metadatas=[{\"author\": \"john\"}, {\"author\": \"jack\"}, {\"author\": \"jill\"}], ids=[\"1\", \"2\", \"3\"])\n", "\n", "collection.get(where={\"$or\": [{\"author\": \"john\"}, {\"author\": \"jack\"}]})\n" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-08-11T18:45:52.663345Z", - "start_time": "2023-08-11T18:42:50.970414Z" - } - } + ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-11T18:49:31.174811Z", + "start_time": "2023-08-11T18:49:31.056618Z" + }, + "collapsed": false + }, "outputs": [ { "data": { - "text/plain": "{'ids': ['1'],\n 'embeddings': None,\n 'metadatas': [{'author': 'john', 'category': 'chroma'}],\n 'documents': ['Article by john']}" + "text/plain": [ + "{'ids': ['1'],\n", + " 'embeddings': None,\n", + " 'metadatas': [{'author': 'john', 'category': 'chroma'}],\n", + " 'documents': ['Article by john'],\n", + " 'uris': None,\n", + " 'data': None}" + ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -226,24 +284,31 @@ "collection.upsert(documents=[\"Article by john\", \"Article by Jack\", \"Article by Jill\"],\n", " metadatas=[{\"author\": \"john\",\"category\":\"chroma\"}, {\"author\": \"jack\",\"category\":\"ml\"}, {\"author\": \"jill\",\"category\":\"lifestyle\"}], ids=[\"1\", \"2\", \"3\"])\n", "collection.get(where={\"$and\": [{\"category\": \"chroma\"}, {\"author\": \"john\"}]})" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-08-11T18:49:31.174811Z", - "start_time": "2023-08-11T18:49:31.056618Z" - } - } + ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-11T18:49:35.758816Z", + "start_time": "2023-08-11T18:49:35.741477Z" + }, + "collapsed": false + }, "outputs": [ { "data": { - "text/plain": "{'ids': [], 'embeddings': None, 'metadatas': [], 'documents': []}" + "text/plain": [ + "{'ids': [],\n", + " 'embeddings': None,\n", + " 'metadatas': [],\n", + " 'documents': [],\n", + " 'uris': None,\n", + " 'data': None}" + ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -251,24 +316,31 @@ "source": [ "# And logical that doesn't match anything\n", "collection.get(where={\"$and\": [{\"category\": \"chroma\"}, {\"author\": \"jill\"}]})" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-08-11T18:49:35.758816Z", - "start_time": "2023-08-11T18:49:35.741477Z" - } - } + ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-11T18:49:40.463045Z", + "start_time": "2023-08-11T18:49:40.450240Z" + }, + "collapsed": false + }, "outputs": [ { "data": { - "text/plain": "{'ids': ['1'],\n 'embeddings': None,\n 'metadatas': [{'author': 'john', 'category': 'chroma'}],\n 'documents': ['Article by john']}" + "text/plain": [ + "{'ids': ['1'],\n", + " 'embeddings': None,\n", + " 'metadatas': [{'author': 'john', 'category': 'chroma'}],\n", + " 'documents': ['Article by john'],\n", + " 'uris': None,\n", + " 'data': None}" + ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -276,22 +348,29 @@ "source": [ "# Combined And and Or Logical Operator Filtering\n", "collection.get(where={\"$and\": [{\"category\": \"chroma\"}, {\"$or\": [{\"author\": \"john\"}, {\"author\": \"jack\"}]}]})" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-08-11T18:49:40.463045Z", - "start_time": "2023-08-11T18:49:40.450240Z" - } - } + ] }, { "cell_type": "code", "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-11T18:51:12.328062Z", + "start_time": "2023-08-11T18:51:12.315943Z" + }, + "collapsed": false + }, "outputs": [ { "data": { - "text/plain": "{'ids': ['1'],\n 'embeddings': None,\n 'metadatas': [{'author': 'john', 'category': 'chroma'}],\n 'documents': ['Article by john']}" + "text/plain": [ + "{'ids': ['1'],\n", + " 'embeddings': None,\n", + " 'metadatas': [{'author': 'john', 'category': 'chroma'}],\n", + " 'documents': ['Article by john'],\n", + " 'uris': None,\n", + " 'data': None}" + ] }, "execution_count": 13, "metadata": {}, @@ -300,23 +379,16 @@ ], "source": [ "collection.get(where_document={\"$contains\": \"Article\"},where={\"$and\": [{\"category\": \"chroma\"}, {\"$or\": [{\"author\": \"john\"}, {\"author\": \"jack\"}]}]})" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-08-11T18:51:12.328062Z", - "start_time": "2023-08-11T18:51:12.315943Z" - } - } + ] }, { "cell_type": "code", "execution_count": null, - "outputs": [], - "source": [], "metadata": { "collapsed": false - } + }, + "outputs": [], + "source": [] } ], "metadata": { @@ -335,7 +407,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.11.6" }, "orig_nbformat": 4, "vscode": {