Skip to content

Commit

Permalink
[ENH] add not contains filter to where clause (#1469)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Fixes #1082

This PR also includes some formatting-only changes.

## Test plan
- [ ] test_filtering.py and test_metadata.py

## Documentation Changes
Will update the user guide in the docs;
an example has been added to Where Filtering.

---------
authored-by: Weili Gu <[email protected]>
  • Loading branch information
weiligu authored Dec 6, 2023
1 parent 8adb20a commit 88db738
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 74 deletions.
4 changes: 3 additions & 1 deletion chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,10 @@ def validate_embedding_function(
"Please note the recent change to the EmbeddingFunction interface: https://docs.trychroma.com/migration#migration-to-0416---november-7-2023 \n"
)


L = TypeVar("L", covariant=True)


class DataLoader(Protocol[L]):
def __call__(self, uris: URIs) -> L:
...
Expand Down Expand Up @@ -404,7 +406,7 @@ def validate_where_document(where_document: WhereDocument) -> WhereDocument:
f"Expected where document to have exactly one operator, got {where_document}"
)
for operator, operand in where_document.items():
if operator not in ["$contains", "$and", "$or"]:
if operator not in ["$contains", "$not_contains", "$and", "$or"]:
raise ValueError(
f"Expected where document operator to be one of $contains, $and, $or, got {operator}"
)
Expand Down
13 changes: 13 additions & 0 deletions chromadb/segment/impl/metadata/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,19 @@ def _where_doc_criterion(
.where(fulltext_t.string_value.like(ParameterValue(search_term)))
)
return embeddings_t.id.isin(sq)
elif k == "$not_contains":
v = cast(str, v)
search_term = f"%{v}%"

sq = (
self._db.querybuilder()
.from_(fulltext_t)
.select(fulltext_t.rowid)
.where(
fulltext_t.string_value.not_like(ParameterValue(search_term))
)
)
return embeddings_t.id.isin(sq)
else:
raise ValueError(f"Unknown where_doc operator {k}")
raise ValueError("Empty where_doc")
Expand Down
8 changes: 7 additions & 1 deletion chromadb/test/property/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,13 @@ def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocu
word = draw(st.sampled_from(collection.known_document_keywords))
else:
word = draw(safe_text)
return {"$contains": word}

op: WhereOperator = draw(st.sampled_from(["$contains", "$not_contains"]))
if op == "$contains":
return {"$contains": word}
else:
assert op == "$not_contains"
return {"$not_contains": word}


def binary_operator_clause(
Expand Down
12 changes: 12 additions & 0 deletions chromadb/test/property/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,23 @@ def _filter_where_doc_clause(clause: WhereDocument, doc: Document) -> bool:
# Simple $contains clause
assert isinstance(expr, str)
if key == "$contains":
if not doc:
return False
# SQLite FTS handles % and _ as word boundaries that are ignored so we need to
# treat them as wildcards
if "%" in expr or "_" in expr:
expr = expr.replace("%", ".").replace("_", ".")
return re.search(expr, doc) is not None
return expr in doc
elif key == "$not_contains":
if not doc:
return False
# SQLite FTS handles % and _ as word boundaries that are ignored so we need to
# treat them as wildcards
if "%" in expr or "_" in expr:
expr = expr.replace("%", ".").replace("_", ".")
return re.search(expr, doc) is None
return expr not in doc
else:
raise ValueError("Unknown operator: {}".format(key))

Expand All @@ -118,6 +129,7 @@ def _filter_embedding_set(
ids = set(normalized_record_set["ids"])

filter_ids = filter["ids"]

if filter_ids is not None:
filter_ids = invariants.wrap(filter_ids)
assert filter_ids is not None
Expand Down
34 changes: 34 additions & 0 deletions chromadb/test/segment/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,17 +296,40 @@ def test_fulltext(
result = segment.get_metadata(where_document={"$contains": "four two"})
assert len(result) == 1

# Test not_contains
result = segment.get_metadata(where_document={"$not_contains": "four two"})
assert len(result) == len(
[i for i in range(1, 100) if "four two" not in _build_document(i)]
)

# Test many results
result = segment.get_metadata(where_document={"$contains": "zero"})
assert len(result) == 9

# Test not_contains
result = segment.get_metadata(where_document={"$not_contains": "zero"})
assert len(result) == len(
[i for i in range(1, 100) if "zero" not in _build_document(i)]
)

# test $and
result = segment.get_metadata(
where_document={"$and": [{"$contains": "four"}, {"$contains": "two"}]}
)
assert len(result) == 2
assert set([r["id"] for r in result]) == {"embedding_42", "embedding_24"}

result = segment.get_metadata(
where_document={"$and": [{"$not_contains": "four"}, {"$not_contains": "two"}]}
)
assert len(result) == len(
[
i
for i in range(1, 100)
if "four" not in _build_document(i) and "two" not in _build_document(i)
]
)

# test $or
result = segment.get_metadata(
where_document={"$or": [{"$contains": "zero"}, {"$contains": "one"}]}
Expand All @@ -316,6 +339,17 @@ def test_fulltext(
expected = set([f"embedding_{i}" for i in set(ones + zeros)])
assert set([r["id"] for r in result]) == expected

result = segment.get_metadata(
where_document={"$or": [{"$not_contains": "zero"}, {"$not_contains": "one"}]}
)
assert len(result) == len(
[
i
for i in range(1, 100)
if "zero" not in _build_document(i) or "one" not in _build_document(i)
]
)

# test combo with where clause (negative case)
result = segment.get_metadata(
where={"int_key": {"$eq": 42}}, where_document={"$contains": "zero"}
Expand Down
4 changes: 3 additions & 1 deletion chromadb/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ class VectorQueryResult(TypedDict):
Union[str, LogicalOperator], Union[LiteralValue, OperatorExpression, List["Where"]]
]

WhereDocumentOperator = Union[Literal["$contains"], LogicalOperator]
WhereDocumentOperator = Union[
Literal["$contains"], Literal["$not_contains"], LogicalOperator
]
WhereDocument = Dict[WhereDocumentOperator, Union[str, List["WhereDocument"]]]


Expand Down
Loading

0 comments on commit 88db738

Please sign in to comment.